1 change: 1 addition & 0 deletions .gitignore
@@ -34,6 +34,7 @@ aten/build/
aten/src/ATen/Config.h
aten/src/ATen/cuda/CUDAConfig.h
benchmarks/.data
benchmarks/data
caffe2/cpp_test/
dist/
docs/build/
54 changes: 33 additions & 21 deletions benchmarks/float8/float8_inference_roofline.py
@@ -29,6 +29,7 @@
import torch
import torch.nn as nn
import tqdm
from tabulate import tabulate
from torch.profiler import ProfilerActivity, profile
from utils import (
get_gpu_kernel_gemm_time_s,
@@ -77,8 +78,11 @@ def get_gemm_times(
K: int,
N: int,
fast_accum: bool,
float8_recipe_name: Optional[str],
recipe_name: Optional[str],
):
assert recipe_name in {"rowwise"}, (
"Only support real benchmarks for 'rowwise' recipe for now"
)
device = torch.device("cuda")

# bf16 time
@@ -100,7 +104,7 @@
.contiguous()
.t()
)
if float8_recipe_name in ("rowwise"):
if recipe_name == "rowwise":
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
else:
@@ -118,41 +122,47 @@ def do_matmul(A, B):

def run(
outfile: str,
recipe_name: str,
do_benchmarks: bool = True,
shape_gen_name: str = "pow2",
n_limit: Optional[int] = None,
float8_recipe_name: Optional[str] = None,
):
"""
Args:
* `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
* `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
* `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
"""

assert float8_recipe_name is not None, "unsupported"

print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"torch version: {torch.__version__}")
print(f"torchao version: {torchao.__version__}")
print(f"do_benchmarks: {do_benchmarks}")
print(f"shape_gen_name: {shape_gen_name}")
print(f"float8_recipe_name: {float8_recipe_name}")
config_table = [
["GPU", torch.cuda.get_device_name(0)],
["torch version", torch.__version__],
["torchao version", torchao.__version__],
["recipe_name", recipe_name],
["do_benchmarks", do_benchmarks],
["shape_gen_name", shape_gen_name],
]
print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))

M, K, N = sympy.symbols("M K N")

fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
M,
K,
N,
float8_recipe_name,
)
bf16_gemm_time_sympy = get_inference_gemm_time_sympy(
M, K, N, torch.bfloat16, None, None
)
fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
M, K, N, torch.float8_e4m3fn, float8_recipe_name, None
recipe_name,
)
bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)

if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
M, K, N, torch.float4_e2m1fn_x2, recipe_name
)
else:
gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None
fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
M, K, N, torch.float8_e4m3fn, gemm_recipe_name
)
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
@@ -219,7 +229,7 @@ def run(
K_val,
N_val,
True,
float8_recipe_name,
recipe_name,
)
b_bf16_gemm_time_s = bf16_g1
b_fp8_gemm_time_s = f8_g1
@@ -261,6 +271,8 @@ def run(
m_fp8_dyn = torch.compile(m_fp8_dyn)
b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)

r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)

results.append(
[
M_val,
@@ -273,7 +285,7 @@ def run(
r_fp8_ovhd_time_s,
# roofline - gemm + overhead, and speedup
r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
r_speedup,
# benchmarks - gemm
b_bf16_gemm_time_s,
b_fp8_gemm_time_s,
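Note: the `r_speedup` value added in this hunk is the roofline speedup estimate, i.e. the bf16 gemm time divided by the sum of the fp8 gemm time and the quantization overhead. A minimal sketch of that calculation, with an illustrative function name that is not part of the PR:

# Sketch only: roofline speedup as computed by the benchmark above.
# A value > 1.0 means the low-precision path is predicted to be faster.
def roofline_speedup(bf16_gemm_time_s, fp8_gemm_time_s, fp8_ovhd_time_s):
    return bf16_gemm_time_s / (fp8_gemm_time_s + fp8_ovhd_time_s)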
1 change: 1 addition & 0 deletions benchmarks/float8/float8_roofline.py
@@ -214,6 +214,7 @@ def run(
* `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
* `gemm_cache_filename (optional)`: file to cache gemm benchmark results
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
* `mx_recipe_name (optional)`: MX format recipe
* `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
"""

89 changes: 64 additions & 25 deletions torchao/testing/training/roofline_utils.py
@@ -12,6 +12,7 @@
BYTES_PER_EL_FLOAT4 = 0.5
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT32 = 4

gpu_name_to_specs = {
"NVIDIA H100": {
@@ -228,7 +229,7 @@ def get_individual_gemm_time_sympy(
K: sympy.Symbol,
N: sympy.Symbol,
dtype,
mx_recipe_name,
mx_recipe_name: Optional[str],
gpu_name: Optional[str] = None,
) -> sympy.Symbol:
# compute bound
@@ -241,27 +242,24 @@
elif dtype is torch.float4_e2m1fn_x2:
peak_tops = specs["fp4_peak_tops"]
else:
assert False, "unsupported"
assert False, f"unsupported dtype: {dtype}"
compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]

# memory bound
num_reads = M * K + K * N
num_writes = M * N

if mx_recipe_name is not None:
assert mx_recipe_name in (
"mxfp8_emulated",
"mxfp8_cublas",
"mxfp8_cublas_rceil",
"mxfp4_cutlass",
), "unsupported"
assert mx_recipe_name.startswith(("mxfp8", "mxfp4", "nvfp4")), (
f"Unsupported recipe {mx_recipe_name}"
)
assert dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float4_e2m1fn_x2,
), "unsupported"
# adjust reads for MX scaling
block_size = 32
block_size = 32 if mx_recipe_name.startswith("mx") else 16
num_scale_reads = num_reads // block_size
# note: e8m0 bytes per element is the same as for e4m3|e5m2
num_reads = num_reads + num_scale_reads
@@ -274,7 +272,7 @@
elif dtype is torch.float4_e2m1fn_x2:
bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
else:
assert False, "unsupported"
assert False, f"unsupported dtype: {dtype}"
mem_gemm_time_s = (
bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
)
@@ -375,28 +373,68 @@ def get_inference_tensor_memory_traffic_ovhd_s(
dim0,
dim1,
tensor_role: str,
float8_recipe_name: Optional[str],
recipe_name: Optional[str],
fuse_with_prev=False,
) -> List[Union[sympy.Symbol, float]]:
"""
Inference version of `get_tensor_memory_traffic_ovhd_s`.
The only thing happening here is we quantize the activation.
"""
assert float8_recipe_name == "rowwise", "unsupported"
assert fuse_with_prev is False, "unsupported"
assert tensor_role == "input", "inference only quantizes input activations"

# assumes input bf16, output f8
numel = dim0 * dim1

res_bytes = None

assert tensor_role == "input"
# x_bf16 = ...
# kernel 1: x_bf16 -> x_fp8
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
res_bytes = [
kernel_1_rw,
]
allowed_recipes = {"tensorwise", "rowwise", "mxfp8*", "nvfp4*", "mxfp4*"}

match recipe_name:
case "tensorwise":
# x_bf16 = ...
# kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
# kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs
# kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
# kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
kernel_1_rw = BYTES_PER_EL_BF16 * numel
# kernel 3: read in bf16, write in float8
kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
res_bytes = [kernel_1_rw, kernel_3_rw]

case "rowwise":
# x_bf16 = ...
# kernel 1: x_bf16 -> x_fp8 (with per-row scaling)
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
# add in the bytes for scale writes
kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
res_bytes = [kernel_1_rw]

case name if name and name.startswith("mxfp8"):
# x_bf16 = ...
# kernel 1: x_bf16 -> x_mxfp8 (block-wise scaling for inference)
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
# add in the bytes for scale writes in E8M0 format
kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // 32)
res_bytes = [kernel_1_rw]

case name if name and (name.startswith("mxfp4") or name.startswith("nvfp4")):
# For NVFP4, assume minimal overhead since it's primarily a compute format
# x_bf16 = ...
# kernel 1: x_bf16 -> x_nvfp4 (per-tensor scaling for inference)
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
if name.startswith("nvfp4"):
kernel_1_rw += BYTES_PER_EL_FLOAT32 # single scale factor
# add in the bytes for scale writes in E4M3 | E8M0
block_size = 32 if name.startswith("mxfp4") else 16
kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // block_size)
res_bytes = [kernel_1_rw]

case _:
raise ValueError(
f"Unknown recipe name: {recipe_name}. "
f"Allowed recipes: {allowed_recipes}"
)

# convert from bytes to seconds
res_s = [
@@ -414,7 +452,7 @@ def get_inference_float8_mem_sympy(
M,
K,
N,
float8_recipe_name: Optional[str],
recipe_name: Optional[str],
gpu_name: Optional[str] = None,
):
specs = get_specs(gpu_name)
@@ -425,7 +463,7 @@
M,
K,
tensor_role="input",
float8_recipe_name=float8_recipe_name,
recipe_name=recipe_name,
fuse_with_prev=False,
)
res = sum([*fwd_fp8_input_mem])
@@ -437,11 +475,12 @@ def get_inference_gemm_time_sympy(
K: sympy.Symbol,
N: sympy.Symbol,
dtype,
float8_recipe_name: Optional[str],
gpu_name: Optional[str],
recipe_name: Optional[str],
gpu_name: Optional[str] = None,
):
assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported"
# note: this function is currently not super accurate for small shapes:
# when M,K,N <= 1k,1k,1k it undercounts by around 2x
gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name)
gemm_output_time_s = get_individual_gemm_time_sympy(
M, K, N, dtype, recipe_name, gpu_name
)
return gemm_output_time_s