192 changes: 154 additions & 38 deletions benchmarks/kernels/bench_per_token_quant_fp8.py
@@ -2,14 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
from unittest.mock import patch

import pandas as pd
import torch

from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


def with_triton_mode(fn):
"""Temporarily force the Triton fallback path"""

def wrapped(*args, **kwargs):
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
return fn(*args, **kwargs)

return wrapped


# TODO(luka): use standalone_compile utility
@@ -21,78 +32,183 @@ def inner(*args):
return inner


torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
torch_per_token_quant_fp8 = torch.compile(
QuantFP8(False, GroupShape.PER_TOKEN),
fullgraph=True,
dynamic=False, # recompile for different shapes
)
def bench_compile(fn: Callable):
# recompile for different shapes
fwd = torch.compile(fn, fullgraph=True, dynamic=False)

# First dim is explicitly dynamic to simulate vLLM usage
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
return with_dyn_arg(fwd, 0, 0)


def cuda_per_token_quant_fp8(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return ops.scaled_fp8_quant(input)
torch._dynamo.config.recompile_limit = 8888


def calculate_diff(batch_size: int, seq_len: int):
"""Calculate difference between Triton and CUDA implementations."""
def calculate_diff(
batch_size: int,
hidden_size: int,
group_shape: GroupShape,
dtype: torch.dtype,
):
"""Calculate the difference between Inductor and CUDA implementations."""
device = torch.device("cuda")
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)

quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

torch_out, torch_scale = torch_per_token_quant_fp8(x)
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

if torch.allclose(
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
out_allclose = lambda o1, o2: torch.allclose(
o1.to(torch.float32),
o2.to(torch.float32),
rtol=1e-3,
atol=1e-5,
)
scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)

if (
out_allclose(cuda_out, torch_out)
and scale_allclose(cuda_scale, torch_scale)
and out_allclose(cuda_out, torch_eager_out)
and scale_allclose(cuda_scale, torch_eager_scale)
):
print("✅ All implementations match")
else:
print("❌ Implementations differ")


batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
hidden_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
batch_sizes = [1, 16, 32, 64, 128]
group_shapes = [
GroupShape.PER_TENSOR,
GroupShape.PER_TOKEN,
GroupShape(1, 64),
GroupShape(1, 128),
]
column_major_scales = [True, False]

config_gen = itertools.product(
group_shapes,
column_major_scales,
batch_sizes,
hidden_sizes,
)

configs = list(itertools.product(batch_size_range, seq_len_range))
# filter out column-major scales for non-group, reverse order
configs = list(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
x_vals=configs,
line_arg="provider",
line_vals=["torch", "cuda"],
line_names=["Torch", "CUDA"],
styles=[("blue", "-"), ("green", "-")],
line_vals=["torch", "cuda", "triton"],
line_names=["Torch (Compiled)", "CUDA", "Triton"],
styles=[("blue", "-"), ("green", "-"), ("black", "-")],
ylabel="us",
plot_name="per-token-dynamic-quant-fp8-performance",
plot_name="QuantFP8 performance",
args={},
)
)
def benchmark_quantization(batch_size, seq_len, provider):
dtype = torch.float16
def benchmark_quantization(
batch_size,
hidden_size,
provider,
group_shape: GroupShape,
col_major: bool,
dtype: torch.dtype,
):
device = torch.device("cuda")

x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)

quantiles = [0.5, 0.2, 0.8]
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

if provider == "torch":
fn = lambda: torch_per_token_quant_fp8(x.clone())
fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
elif provider == "cuda":
fn = lambda: cuda_per_token_quant_fp8(x.clone())
fn = lambda: quant_fp8.forward_cuda(x.clone())
elif provider == "triton":
if not group_shape.is_per_group():
# Triton only supported for per-group
return 0, 0, 0

fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

return 1000 * ms, 1000 * max_ms, 1000 * min_ms


# TODO(luka) extract to utils
def compute_geomean_speedups(
df: pd.DataFrame,
baseline_col: str,
speedup_cols: list[str],
groupby_cols: list[str] | None = None,
) -> pd.DataFrame:
"""
Compute geometric mean speedups over a baseline column.

Args:
df: Input dataframe
baseline_col: Column to use as baseline
speedup_cols: Columns to compute speedups for
groupby_cols: Columns to group by. If None, compute over entire df.

Returns:
pd.DataFrame with geometric mean speedups
"""
from scipy.stats import gmean

def geo_speedup(group: pd.DataFrame) -> pd.Series:
ratios = {
col: (group[baseline_col] / group[col]).values for col in speedup_cols
}
return pd.Series({col: gmean(vals) for col, vals in ratios.items()})

if groupby_cols is None:
result = geo_speedup(df).to_frame().T
else:
result = df.groupby(groupby_cols).apply(geo_speedup).reset_index()

return result


if __name__ == "__main__":
calculate_diff(batch_size=4, seq_len=4096)
benchmark_quantization.run(print_data=True)
parser = FlexibleArgumentParser(
description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
)
parser.add_argument("-c", "--check", action="store_true")
parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
)

args = parser.parse_args()
assert args

dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

if args.check:
for group_shape in group_shapes:
group_size = group_shape[1]
print(f"{group_size=}")
calculate_diff(
batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
)

df = benchmark_quantization.run(print_data=True, dtype=dtype, return_df=True)

# Print geomean speedups
geo_table_grouped = compute_geomean_speedups(
df,
baseline_col="Torch (Compiled)",
speedup_cols=["CUDA", "Triton"],
groupby_cols=["col_major", "group_shape"],
)

print("Speedup over Torch (Compiled)")
print(geo_table_grouped.to_string(index=False))
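
For readers unfamiliar with the new helper, a minimal sketch of what compute_geomean_speedups produces on a hand-made frame. The import path is an assumption (it presumes benchmarks/kernels is on PYTHONPATH); the column names mirror the line_names and groupby_cols used in the benchmark above, and the numbers are made up for illustration.

import pandas as pd

# Assumption: the benchmark file above is importable from the current directory.
from bench_per_token_quant_fp8 import compute_geomean_speedups

# Toy frame mirroring the benchmark output columns (times in microseconds).
df = pd.DataFrame(
    {
        "group_shape": ["PER_TOKEN", "PER_TOKEN", "(1, 128)", "(1, 128)"],
        "col_major": [False, False, True, True],
        "Torch (Compiled)": [10.0, 20.0, 12.0, 24.0],
        "CUDA": [5.0, 10.0, 4.0, 8.0],
        "Triton": [8.0, 16.0, 6.0, 12.0],
    }
)

speedups = compute_geomean_speedups(
    df,
    baseline_col="Torch (Compiled)",
    speedup_cols=["CUDA", "Triton"],
    groupby_cols=["col_major", "group_shape"],
)
print(speedups)
# Each row is the geometric mean of Torch/CUDA and Torch/Triton time ratios
# within its (col_major, group_shape) group, e.g. 2.00 and 1.25 for PER_TOKEN above.
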
148 changes: 148 additions & 0 deletions benchmarks/kernels/benchmark_quantfp8_group.py
@@ -0,0 +1,148 @@
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark for QuantFP8 Group Quantization implementation."""

import argparse

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform


def _time_cuda(
fn,
warmup_iters: int,
bench_iters: int,
) -> float:
# warmup
for _ in range(warmup_iters):
fn()
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(bench_iters):
fn()
end.record()
torch.cuda.synchronize()

return start.elapsed_time(end) / bench_iters # ms/iter


def run_benchmark(
shape: tuple[int, int],
group_size: int,
column_major: bool,
warmup_iters: int,
bench_iters: int,
) -> None:
"""Benchmark QuantFP8 with group quantization using different backends."""
num_tokens, hidden_dim = shape

device = torch.device("cuda")
torch.manual_seed(42)
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(
static=False, group_shape=group_shape, column_major_scales=column_major
)

def cuda_impl():
return quant_op.forward_cuda(x.clone())

def native_impl():
return quant_op.forward_native(x.clone())

cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
native_ms = _time_cuda(native_impl, warmup_iters, bench_iters)

speedup = cuda_ms / native_ms if native_ms else 0

cfg_desc = f"shape={shape} gs={group_size:<3} col_major={column_major}"
print(f"{cfg_desc:45} | {cuda_ms:7.3f} | {native_ms:7.3f} | {speedup:6.2f}x")


def parse_args():
parser = argparse.ArgumentParser(
description="Benchmark QuantFP8 group quantization implementation"
)
parser.add_argument(
"--warmup-iters", type=int, default=10, help="Number of warmup iterations"
)
parser.add_argument(
"--bench-iters", type=int, default=100, help="Number of benchmark iterations"
)
parser.add_argument(
"--shapes",
type=str,
default="32,128;64,256;16,512;128,1024;256,2048",
help="Shapes to benchmark as 'tokens,hidden;...' (default: multiple shapes)",
)
parser.add_argument(
"--group-sizes",
type=str,
default="64,128",
help="Group sizes to benchmark (comma-separated)",
)
parser.add_argument(
"--no-column-major",
action="store_true",
help="Skip column-major scale benchmarks",
)
return parser.parse_args()


def main():
if not current_platform.is_cuda():
raise RuntimeError("CUDA device is required to run this benchmark.")

args = parse_args()

shapes = []
for shape_str in args.shapes.split(";"):
tokens, hidden = map(int, shape_str.split(","))
shapes.append((tokens, hidden))

group_sizes = list(map(int, args.group_sizes.split(",")))

print("\n" + "=" * 80)
print("QuantFP8 Group Quantization Benchmark (CUDA kernel vs PyTorch native)")
print("=" * 80)
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Warmup iterations: {args.warmup_iters}")
print(f"Benchmark iterations: {args.bench_iters}")
print("=" * 80)

print(f"{'Configuration':45} | {'CUDA':^9} | {'Native':^9} | {'Speedup':^8}")
print("-" * 80)

for shape in shapes:
for gs in group_sizes:
run_benchmark(
shape,
gs,
column_major=False,
warmup_iters=args.warmup_iters,
bench_iters=args.bench_iters,
)

if not args.no_column_major:
run_benchmark(
shape,
gs,
column_major=True,
warmup_iters=args.warmup_iters,
bench_iters=args.bench_iters,
)

print("=" * 80)


if __name__ == "__main__":
main()
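
As a rough cross-check of what run_benchmark measures, here is a minimal sketch that compares the two QuantFP8 paths directly on one input. It assumes a CUDA device; the constructor arguments and the (output, scale) return pair follow the code above, while the tolerances are borrowed from calculate_diff in the first file and are an assumption in this context.

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# 64 tokens x 256 features, group size 128 along the hidden dimension.
x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16) * 8

quant_op = QuantFP8(
    static=False, group_shape=GroupShape(1, 128), column_major_scales=False
)

cuda_out, cuda_scale = quant_op.forward_cuda(x.clone())
native_out, native_scale = quant_op.forward_native(x.clone())

# Compare in float32 to avoid fp8 comparison quirks; tolerances as in calculate_diff.
assert torch.allclose(cuda_out.float(), native_out.float(), rtol=1e-3, atol=1e-5)
assert torch.allclose(cuda_scale, native_scale, rtol=1e-3, atol=1e-5)
print("forward_cuda and forward_native agree on this input")
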