Merged
Commits
50 commits
c9ca102
Move apply_w8a8_block_fp8_linear to an op class
ElizaWszola Sep 11, 2025
eef4349
Remove TODO, bring back old one
ElizaWszola Sep 11, 2025
dd53183
CUDA graphs fix
ElizaWszola Sep 11, 2025
bb24881
Clean up
ElizaWszola Sep 11, 2025
1ba47cd
Create linear op objects conditionally, move some arch checks to bloc…
ElizaWszola Sep 11, 2025
02793b9
format
ElizaWszola Sep 11, 2025
b72c9f2
clean up repetitive code
ElizaWszola Sep 12, 2025
d51f35c
More aggressive dispatch of blockscale ops
ElizaWszola Sep 12, 2025
a6ae689
fix
ElizaWszola Sep 12, 2025
3238ff6
Deep_gemm fix
ElizaWszola Sep 12, 2025
f9c79aa
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 12, 2025
23341c2
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 12, 2025
9b09b60
Post-merge fixes, better dispatch
ElizaWszola Sep 12, 2025
e6b0028
small fixes
ElizaWszola Sep 12, 2025
9b5c552
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 15, 2025
ef6f1e2
Fix cutlass compilation issue on Hopper
ElizaWszola Sep 17, 2025
77335de
Cleanup bad transpose
ElizaWszola Sep 17, 2025
5eaf155
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 17, 2025
e036dac
Wrap w8a8_block_fp8_matmul
ElizaWszola Sep 17, 2025
233e874
Rename padded_cutlass to padded_cutlass_scaled_mm, add todo
ElizaWszola Sep 17, 2025
1edfedc
Cleanup dispatch_w8a8_blockscale_func
ElizaWszola Sep 17, 2025
35a0236
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 18, 2025
0ac3a1e
Deep gemm warmup fix
ElizaWszola Sep 18, 2025
9a48100
Fix deep gemm support function
ElizaWszola Sep 18, 2025
b6a8fb8
Feedback
ElizaWszola Sep 19, 2025
e89ecd8
Pre-commit fixes
ElizaWszola Sep 19, 2025
00cb05c
Pre-commit fixes 2
ElizaWszola Sep 19, 2025
66c89e6
Feedback
ElizaWszola Sep 19, 2025
d9b4121
fix type issue
ElizaWszola Sep 19, 2025
1bc81a1
Add use_ue8m0 support to _quantize_group_native
ElizaWszola Sep 19, 2025
ec73268
Fix padding compilation issue
ElizaWszola Sep 22, 2025
d19bf4b
Feedback
ElizaWszola Sep 22, 2025
1f895e9
Update vllm/model_executor/layers/quantization/utils/fp8_utils.py
ElizaWszola Sep 22, 2025
be3ac58
Link bad group shape issue
ElizaWszola Sep 22, 2025
3772f2f
format
ElizaWszola Sep 22, 2025
8b6cbe4
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 22, 2025
2a87a3b
fix quant config condition
ElizaWszola Sep 22, 2025
012eaff
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
mgoin Sep 22, 2025
e7f6ec9
fix quant issue (TODO test)
ProExpertProg Sep 22, 2025
10829d3
fix custom op test
ProExpertProg Sep 22, 2025
15cf30e
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 23, 2025
ebdcb10
CUDA condition for compressed tensors and H100
ElizaWszola Sep 23, 2025
2e3d206
Fix quantfp8 test
ElizaWszola Sep 23, 2025
bd32cb9
Test scales_col vs. scales_native
ElizaWszola Sep 23, 2025
efa4446
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ElizaWszola Sep 23, 2025
1f00804
Add compressed tensors model test
ElizaWszola Sep 23, 2025
e895df6
Extra asserts, don't use enabled()
ElizaWszola Sep 23, 2025
9806cf8
CUDA path for quant
ProExpertProg Sep 23, 2025
2ae1ef9
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ProExpertProg Sep 23, 2025
00bd638
Merge branch 'main' into move-apply_w8a8_block_fp8_linear-to-class
ProExpertProg Sep 23, 2025
4 changes: 2 additions & 2 deletions benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -17,7 +17,7 @@

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul,
w8a8_triton_block_scaled_mm,
)
from vllm.utils import FlexibleArgumentParser, cdiv

@@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
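The rename from w8a8_block_fp8_matmul to w8a8_triton_block_scaled_mm is mechanical; the call convention is unchanged. A rough standalone sketch of an invocation, with shapes and dtypes assumed from the benchmark and tests in this PR rather than taken from any authoritative reference:

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_triton_block_scaled_mm)

# Assumed convention: A is (M, K) fp8 with per-token-group scales,
# B is (N, K) fp8 with per-(128, 128)-block scales in float32.
M, N, K, block = 16, 4096, 4096, 128
fp8_dtype = torch.float8_e4m3fn  # assumes a CUDA device; ROCm uses the fnuz variant

A = torch.randn(M, K, device="cuda").to(fp8_dtype)
B = torch.randn(N, K, device="cuda").to(fp8_dtype)
As = torch.rand(M, K // block, device="cuda", dtype=torch.float32)
Bs = torch.rand(N // block, K // block, device="cuda", dtype=torch.float32)

out = w8a8_triton_block_scaled_mm(A, B, As, Bs, (block, block), torch.bfloat16)
assert out.shape == (M, N)

The benchmark above appears to keep its weight in (K, N) layout, which is why it passes b.t() and block_scale_b.t() to the kernel.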
4 changes: 2 additions & 2 deletions benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
@@ -10,7 +10,7 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor,
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
w8a8_triton_block_scaled_mm,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
@@ -59,7 +59,7 @@ def deepgemm_gemm():

# === vLLM Triton Implementation ===
def vllm_triton_gemm():
return w8a8_block_fp8_matmul(A_vllm,
return w8a8_triton_block_scaled_mm(A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
5 changes: 3 additions & 2 deletions tests/kernels/quantization/test_block_fp8.py
@@ -12,7 +12,7 @@
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
@@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):

ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)

rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
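For intuition about what the renamed Triton kernel is checked against: native_w8a8_block_matmul, the reference used in the test above, dequantizes each (block_n, block_k) tile with its scale and then performs an ordinary matmul. A small self-contained sketch of that reference semantics; the function name here is illustrative, not a vLLM API:

import torch


def reference_block_scaled_mm(A, B, As, Bs, block_size,
                              out_dtype=torch.bfloat16):
    # A: (M, K) fp8, As: (M, ceil(K / block_k)) float32 activation scales
    # B: (N, K) fp8, Bs: (ceil(N / block_n), ceil(K / block_k)) float32 weight scales
    # Shapes assumed from the tests above.
    block_n, block_k = block_size
    a = A.to(torch.float32)
    b = B.to(torch.float32)
    # Broadcast each scale back over the elements of its block, then trim
    # to the original sizes in case N or K is not a multiple of the block.
    a_scale = As.repeat_interleave(block_k, dim=1)[:, :a.shape[1]]
    b_scale = Bs.repeat_interleave(block_n, dim=0).repeat_interleave(
        block_k, dim=1)[:b.shape[0], :b.shape[1]]
    return ((a * a_scale) @ (b * b_scale).t()).to(out_dtype)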
26 changes: 19 additions & 7 deletions tests/kernels/quantization/test_fp8_quant_group.py
@@ -20,9 +20,11 @@
(8, 513, 64), # Non-divisible (native only)
])
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int) -> None:
group_size: int, seed: int,
use_ue8m0: bool) -> None:
"""Test QuantFP8 group quantization with various configurations.

Tests both CUDA and native implementations, column-major scales,
@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)
column_major_scales=False,
use_ue8m0=use_ue8m0)

# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
column_major_scales=True,
use_ue8m0=use_ue8m0)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (expected_num_groups, batch_size)
assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
assert scales_col.stride(1) == batch_size

# Test column-major scales consistency
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)

# 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible:
@@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,


@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
current_platform.seed_everything(seed)

group_size = 64
@@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)
column_major_scales=False,
use_ue8m0=use_ue8m0)

x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
@@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
column_major_scales=True,
use_ue8m0=use_ue8m0)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

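The updated assertions above pin down the column-major scales contract: the tensor is logically (batch_size, num_groups) but stored column-major, so stride(0) == 1 and stride(1) == batch_size, while the values match the row-major (native) scales. A tiny standalone illustration of a tensor with exactly that layout; building it via .t().contiguous().t() is just one convenient way to get those strides:

import torch

batch_size, num_groups = 8, 16
scales_native = torch.rand(batch_size, num_groups, dtype=torch.float32)

# Same logical values, column-major storage: transpose a contiguous
# (num_groups, batch_size) buffer back to (batch_size, num_groups).
scales_col = scales_native.t().contiguous().t()

assert scales_col.shape == (batch_size, num_groups)
assert scales_col.stride() == (1, batch_size)
assert torch.allclose(scales_col, scales_native)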
30 changes: 0 additions & 30 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -17,8 +17,6 @@
from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform

RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()


@pytest.mark.skipif(
not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
use_rocm_aiter_gemm_w8a8_blockscale: str,
monkeypatch):

monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
use_rocm_aiter_gemm_w8a8_blockscale)

use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
int(use_rocm_aiter_gemm_w8a8_blockscale)))
block_scale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
if use_cutlass:
assert block_scale_func == cutlass_scaled_mm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_gemm_w8a8_blockscale):
assert block_scale_func == (
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
else:
assert block_scale_func == w8a8_block_fp8_matmul


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
35 changes: 35 additions & 0 deletions tests/quantization/test_compressed_tensors.py
@@ -18,6 +18,9 @@
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
perplexity = llm.generate_prompt_perplexity([prompt])[0]
print(perplexity)
assert perplexity <= exp_perplexity


def test_compressed_tensors_fp8_block_enabled(vllm_runner):
model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
with vllm_runner(model_path) as llm:

fp8_dtype = current_platform.fp8_dtype()

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
W8A8BlockFp8LinearOp)

assert qkv_proj.weight.dtype is fp8_dtype
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight.shape) == 2
assert len(qkv_proj.weight_scale.shape) == 2

input_quant_op = \
qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method == input_quant_op.forward_cuda

llm.apply_model(check_model)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
17 changes: 17 additions & 0 deletions vllm/config/__init__.py
@@ -687,6 +687,23 @@ def __post_init__(self):
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False

# Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than
# native implementation
# https://github.com/vllm-project/vllm/issues/25094
if has_blocked_weights():
custom_ops = self.compilation_config.custom_ops
if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
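In isolation, the new __post_init__ logic boils down to: if the quant config declares blocked weights (via a weight_block_size attribute or a has_blocked_weights() method), opt the model into the +quant_fp8 custom op unless the user explicitly disabled it. A hedged sketch of that decision with a stubbed config object; the stub class and helper name are hypothetical:

from typing import Optional


class _StubQuantConfig:
    # Hypothetical stand-in for a quant config with DeepSeek-style block scales.
    def __init__(self, weight_block_size: Optional[list[int]] = None):
        self.weight_block_size = weight_block_size


def _maybe_enable_quant_fp8(quant_config, custom_ops: list[str]) -> None:
    # Mirrors has_blocked_weights() plus the opt-out check above.
    has_blocked = False
    if quant_config is not None:
        if hasattr(quant_config, "weight_block_size"):
            has_blocked = quant_config.weight_block_size is not None
        elif hasattr(quant_config, "has_blocked_weights"):
            has_blocked = quant_config.has_blocked_weights()
    if has_blocked and "none" not in custom_ops \
            and "-quant_fp8" not in custom_ops:
        custom_ops.append("+quant_fp8")


ops: list[str] = []
_maybe_enable_quant_fp8(_StubQuantConfig(weight_block_size=[128, 128]), ops)
print(ops)  # ['+quant_fp8']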
@@ -644,6 +644,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
# If no matches, return None
return None

def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
weight_quant = scheme.get("weights")
if (weight_quant is not None
and weight_quant.strategy == QuantizationStrategy.BLOCK):
return True
return False

@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
@@ -11,7 +11,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
@@ -41,16 +41,30 @@ def __init__(self, weight_quant: QuantizationArgs,
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)

self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN

self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

if self.weight_block_size is not None:
assert not self.is_static_input_scheme
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)

@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
@@ -141,13 +155,14 @@ def apply_weights(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

if layer.weight_block_size is not None:
return apply_fp8_block_linear(
layer,
if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
)

return self.fp8_linear.apply(input=x,
weight=layer.weight,
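Taken together, the scheme now builds a single W8A8BlockFp8LinearOp per block-quantized layer and routes apply_weights through it. A rough, hedged usage sketch based only on the names visible in this diff; the GroupShape import path and the tensor shapes are assumptions, and the cutlass/AITER paths are forced off here to keep the sketch simple:

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

block_n = block_k = 128
linear_op = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(block_n, block_k),
    act_quant_group_shape=GroupShape(1, block_k),
    cutlass_block_fp8_supported=False,
    use_aiter_and_is_supported=False,
)

M, N, K = 4, 1024, 2048
weight = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(N // block_n, K // block_k,
                          device="cuda", dtype=torch.float32)
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

out = linear_op.apply(input=x, weight=weight, weight_scale=weight_scale,
                      input_scale=None, bias=None)
assert out.shape == (M, N)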
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/quantization/deepgemm.py
@@ -43,7 +43,7 @@ def prepare_block_fp8_matmul_inputs(
return M, N, K, C


def w8a8_block_fp8_matmul_deepgemm(
def w8a8_deepgemm_block_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -59,7 +59,7 @@ def w8a8_block_fp8_matmul_deepgemm(
return C


def w8a8_block_fp8_matmul_deepgemm_fake(
def w8a8_deepgemm_block_scaled_mm_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -73,9 +73,9 @@ def w8a8_block_fp8_matmul_deepgemm_fake(


direct_register_custom_op(
op_name="w8a8_block_fp8_matmul_deepgemm",
op_func=w8a8_block_fp8_matmul_deepgemm,
op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_deepgemm_block_scaled_mm,
mutates_args=[],
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key,
)
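
The rename also flows through the custom-op registration, so compiled graphs resolve the kernel under its new name (presumably as torch.ops.vllm.w8a8_deepgemm_block_scaled_mm). For reference, a hedged sketch of the same direct_register_custom_op pattern applied to a toy op; the import paths and the vllm op namespace are assumptions:

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def toy_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    return (a @ b) * scale


def toy_scaled_mm_fake(a: torch.Tensor, b: torch.Tensor,
                       scale: float) -> torch.Tensor:
    # Shape-only implementation used when tracing with fake/meta tensors.
    return a.new_empty((a.shape[0], b.shape[1]))


direct_register_custom_op(
    op_name="toy_scaled_mm",
    op_func=toy_scaled_mm,
    mutates_args=[],
    fake_impl=toy_scaled_mm_fake,
    dispatch_key=current_platform.dispatch_key,
)

# The dispatch key follows the current platform, so the tensors should live
# on that platform's device (e.g. CUDA here).
out = torch.ops.vllm.toy_scaled_mm(
    torch.randn(4, 8, device="cuda"), torch.randn(8, 16, device="cuda"), 0.5)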