diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index a5a5b52f6039..02f8c593392c 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -17,7 +17,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.utils import FlexibleArgumentParser, cdiv @@ -158,7 +158,7 @@ def bench_fp8( "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) ), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm( a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) ), "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index b99c2099f2c3..b3c3742825de 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.triton_utils import triton from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8 @@ -59,7 +59,7 @@ def deepgemm_gemm(): # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_block_fp8_matmul(A_vllm, + return w8a8_triton_block_scaled_mm(A_vllm, B_vllm, A_scale_vllm, B_scale_vllm, diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c440747316b8..c0b934fc55ae 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -12,7 +12,7 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( cutlass_scaled_mm, get_col_major_tma_aligned_tensor, - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + per_token_group_quant_fp8, w8a8_triton_block_scaled_mm) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 @@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 720eee62760d..3d4c851a9b88 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -20,9 +20,11 @@ (8, 513, 64), # Non-divisible (native only) ]) @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, - group_size: int, seed: int) -> None: + group_size: int, seed: int, + use_ue8m0: bool) -> None: """Test 
QuantFP8 group quantization with various configurations. Tests both CUDA and native implementations, column-major scales, @@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) # 1. Test native implementation (always available) x_quant_native, scales_native = quant_op.forward_native(x.clone()) @@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, # 2. Test column-major scales configuration quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x.clone()) - assert scales_col.shape == (expected_num_groups, batch_size) + assert scales_col.shape == (batch_size, expected_num_groups) + assert scales_col.stride(0) == 1 + assert scales_col.stride(1) == batch_size + + # Test column-major scales consistency + assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) # 3. Test CUDA implementation (only for divisible dimensions) if is_divisible: @@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() -def test_quantfp8_group_multidimensional(seed: int) -> None: +def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: current_platform.seed_everything(seed) group_size = 64 @@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) x_quant, scales = quant_op.forward_native(x_3d.clone()) assert x_quant.shape == x_3d.shape @@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: # Test column_major_scales with multi-dim quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x_3d.clone()) assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 92ce10a9efc0..200b6ecd5852 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -17,8 +17,6 @@ from vllm.model_executor.layers.layernorm import (RMSNorm, dispatch_rocm_rmsnorm_func, fused_add_rms_norm, rms_norm) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.skipif( - not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), - reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") -@pytest.mark.parametrize("use_cutlass", [True, False]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) -def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, - 
use_rocm_aiter_gemm_w8a8_blockscale: str, - monkeypatch): - - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", - use_rocm_aiter_gemm_w8a8_blockscale) - - use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( - int(use_rocm_aiter_gemm_w8a8_blockscale))) - block_scale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) - if use_cutlass: - assert block_scale_func == cutlass_scaled_mm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_gemm_w8a8_blockscale): - assert block_scale_func == ( - torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) - else: - assert block_scale_func == w8a8_block_fp8_matmul - - @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c0ab3fbb1062..af8c7ec3b482 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -18,6 +18,9 @@ CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, perplexity = llm.generate_prompt_perplexity([prompt])[0] print(perplexity) assert perplexity <= exp_perplexity + + +def test_compressed_tensors_fp8_block_enabled(vllm_runner): + model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" + with vllm_runner(model_path) as llm: + + fp8_dtype = current_platform.fp8_dtype() + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) + assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear, + W8A8BlockFp8LinearOp) + + assert qkv_proj.weight.dtype is fp8_dtype + assert qkv_proj.weight_scale.dtype is torch.float32 + assert len(qkv_proj.weight.shape) == 2 + assert len(qkv_proj.weight_scale.shape) == 2 + + input_quant_op = \ + qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op + assert isinstance(input_quant_op, QuantFP8) + assert input_quant_op._forward_method == input_quant_op.forward_cuda + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index a2562a10b45a..50e8cad23617 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -687,6 +687,23 @@ def __post_init__(self): # local attention. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True + def has_blocked_weights(): + if self.quant_config is not None: + if hasattr(self.quant_config, "weight_block_size"): + return self.quant_config.weight_block_size is not None + elif hasattr(self.quant_config, "has_blocked_weights"): + return self.quant_config.has_blocked_weights() + return False + + # Enable quant_fp8 CUDA ops (TODO disable in follow up) + # On H100 the CUDA kernel is faster than + # native implementation + # https://github.com/vllm-project/vllm/issues/25094 + if has_blocked_weights(): + custom_ops = self.compilation_config.custom_ops + if "none" not in custom_ops and "-quant_fp8" not in custom_ops: + custom_ops.append("+quant_fp8") + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d6550dd16892..3f771ea2abd1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -644,6 +644,14 @@ def get_cache_scale(self, name: str) -> Optional[str]: # If no matches, return None return None + def has_blocked_weights(self) -> bool: + for scheme in self.target_scheme_map.values(): + weight_quant = scheme.get("weights") + if (weight_quant is not None + and weight_quant.strategy == QuantizationStrategy.BLOCK): + return True + return False + @staticmethod def supports_cutlass_24( weight_quant: Optional[QuantizationArgs], diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index d42ae22c5139..fa0816959fcd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_block_linear, check_aiter_fp8_linear_support, + W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, @@ -41,16 +41,30 @@ def __init__(self, weight_quant: QuantizationArgs, self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.act_q_group_shape = GroupShape.PER_TENSOR \ - if is_static_input_scheme else GroupShape.PER_TOKEN - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape) self.weight_block_size = self.weight_quant.block_structure + if self.weight_block_size is not None: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + else: + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + if self.weight_block_size is not None: + assert 
not self.is_static_input_scheme + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape) + @classmethod def get_min_capability(cls) -> int: # lovelace and up @@ -141,13 +155,14 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if layer.weight_block_size is not None: - return apply_fp8_block_linear( - layer, + if self.weight_block_size is not None: + return self.w8a8_block_fp8_linear.apply( input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported) + ) return self.fp8_linear.apply(input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index d26a932eddb2..c2b3ccf19fca 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -43,7 +43,7 @@ def prepare_block_fp8_matmul_inputs( return M, N, K, C -def w8a8_block_fp8_matmul_deepgemm( +def w8a8_deepgemm_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -59,7 +59,7 @@ def w8a8_block_fp8_matmul_deepgemm( return C -def w8a8_block_fp8_matmul_deepgemm_fake( +def w8a8_deepgemm_block_scaled_mm_fake( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -73,9 +73,9 @@ def w8a8_block_fp8_matmul_deepgemm_fake( direct_register_custom_op( - op_name="w8a8_block_fp8_matmul_deepgemm", - op_func=w8a8_block_fp8_matmul_deepgemm, + op_name="w8a8_deepgemm_block_scaled_mm", + op_func=w8a8_deepgemm_block_scaled_mm, mutates_args=[], - fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, + fake_impl=w8a8_deepgemm_block_scaled_mm_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2b24e052053c..c4951712baa7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -31,7 +31,7 @@ register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_block_linear, check_aiter_fp8_linear_support, + W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, expert_weight_is_col_major, get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block, @@ -234,15 +234,28 @@ def __init__(self, quant_config: Fp8Config): self.weight_block_size = self.quant_config.weight_block_size self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - # Use per-token quantization for better perf if dynamic and cutlass - if not self.act_q_static and cutlass_fp8_supported(): - self.act_q_group_shape = GroupShape.PER_TOKEN + if self.weight_block_size: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) else: - self.act_q_group_shape = GroupShape.PER_TENSOR + # Use per-token quantization 
for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape) + if self.block_quant: + assert not self.act_q_static + assert self.weight_block_size is not None + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -391,12 +404,15 @@ def apply(self, bias=bias) if self.block_quant: - return apply_fp8_block_linear( - layer, + assert self.weight_block_size is not None + + return self.w8a8_block_fp8_linear.apply( input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported) + ) return self.fp8_linear.apply(input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 31182f40b48f..ece3e5817116 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -27,11 +27,14 @@ class QuantFP8(CustomOp): This CustomOp supports both static and dynamic quantization. """ - def __init__(self, - static: bool, - group_shape: GroupShape, - num_token_padding: Optional[int] = None, - column_major_scales: bool = False): + def __init__( + self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None, + column_major_scales: bool = False, + use_ue8m0: Optional[bool] = None, # for Torch compile + ): """ :param static: static or dynamic quantization :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, @@ -46,6 +49,7 @@ def __init__(self, self.group_shape = group_shape self.num_token_padding = num_token_padding self.column_major_scales = column_major_scales + self.use_ue8m0 = use_ue8m0 self.is_group_quant = group_shape.is_per_group() if self.is_group_quant: @@ -70,7 +74,8 @@ def forward_cuda( x, group_size=self.group_size, column_major_scales=self.column_major_scales, - dtype=_FP8_DTYPE) + dtype=_FP8_DTYPE, + use_ue8m0=self.use_ue8m0) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape @@ -137,7 +142,10 @@ def _quantize_group_native( x_grouped = x.view(-1, num_groups, self.group_size) absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() - scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + scales_raw = absmax / _FP8_MAX + if self.use_ue8m0: + scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw))) + scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) x_scaled = x_grouped / scales x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) @@ -151,6 +159,6 @@ def _quantize_group_native( scales = scales.reshape(orig_shape[:-1] + (num_groups, )) if self.column_major_scales: - scales = scales.transpose(-2, -1).contiguous() + scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2) return x_quant, 
scales diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index d1d87b7ba12e..2098086bf240 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -13,8 +13,9 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) + GroupShape, group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.model_executor.parameter import (BlockQuantScaleParameter, @@ -24,6 +25,7 @@ from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz +# We need to pass in the is_hopper flag as argument because the function +# current_platform.is_device_capability() is not supported by Torch compiler. def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -42,15 +46,17 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, + is_hopper: Optional[bool] = None, ) -> torch.Tensor: + if is_hopper is None: + is_hopper = current_platform.is_device_capability(90) return ops.cutlass_scaled_mm( A, B.T, out_dtype=output_dtype, scale_a=As, # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None - and current_platform.is_device_capability(90) else Bs.T) + scale_b=Bs if block_size is not None and is_hopper else Bs.T) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -98,122 +104,189 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def dispatch_w8a8_blockscale_func( - use_cutlass: bool, use_aiter_and_is_supported: bool -) -> Callable[[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - list[int], - torch.dtype, -], torch.Tensor]: - if use_cutlass: - return cutlass_scaled_mm - if (use_aiter_and_is_supported): - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale - return w8a8_block_fp8_matmul +# TODO we should be able to change the type of block_size to GroupShape +# after we resolve GroupShape compilation issue +# https://github.com/vllm-project/vllm/issues/25270 +def _w8a8_triton_block_scaled_mm_func( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale, + block_size, output_dtype) -# TODO fix ROCm->Triton custom path: -# https://github.com/vllm-project/vllm/issues/14397 -def apply_w8a8_block_fp8_linear( - input: torch.Tensor, +def _w8a8_triton_block_scaled_mm_fake( + qx: torch.Tensor, weight: torch.Tensor, - block_size: list[int], + x_scale: torch.Tensor, weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, + block_size: list[int], + 
output_dtype: torch.dtype, ) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype + return torch.empty((qx.size(0), weight.size(0)), + dtype=output_dtype, + device=qx.device) + - if should_use_deepgemm_for_fp8_linear(output_dtype, weight): +direct_register_custom_op( + "w8a8_triton_block_scaled_mm_func", + _w8a8_triton_block_scaled_mm_func, + mutates_args=[], + fake_impl=_w8a8_triton_block_scaled_mm_fake, + dispatch_key="CUDA", +) + +# TODO fix ROCm->Triton custom path: +# https://github.com/vllm-project/vllm/issues/14397 +class W8A8BlockFp8LinearOp: + """ + This class executes a Blocked FP8 linear layer using cutlass if supported + and torch.scaled_mm otherwise. + """ + + def __init__( + self, + weight_group_shape: GroupShape, + act_quant_group_shape: GroupShape, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, + ): + self.weight_group_shape = weight_group_shape + self.act_quant_group_shape = act_quant_group_shape + self.is_deep_gemm_supported = is_deep_gemm_supported() + self.is_hopper = current_platform.is_device_capability(90) + + # Get the correct blockscale mul and input quant operations. + # We can't use _dispatch_w8a8_blockscale_op to figure out if we want + # to use deepgemm because we don't know the shape of weights (and + # whether deepgemm supports it) at the init time. + self.w8a8_blockscale_op, self.input_quant_op = \ + self._dispatch_w8a8_blockscale_op( + cutlass_block_fp8_supported, use_aiter_and_is_supported) + self.deepgemm_input_quant_op = (QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported + else None) + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=True, - ) + if should_use_deepgemm_for_fp8_linear(output_dtype, weight, + self.is_deep_gemm_supported): + output = self._run_deepgemm(input, weight, weight_scale) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + def _run_deepgemm( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: # ensure DeepGEMM-backed custom op is registered before use import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 - output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + assert self.deepgemm_input_quant_op is not None + q_input, x_scale = self.deepgemm_input_quant_op(input_2d) + return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm( q_input, weight, x_scale, weight_scale, - block_size, - output_dtype=output_dtype) - if bias is not None: - output += bias - return output.to(dtype=output_dtype).view(*output_shape) - - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( - 
cutlass_block_fp8_supported, use_aiter_and_is_supported) - if cutlass_block_fp8_supported: - num_pad = 0 - if current_platform.is_device_capability(90): - # pad first dimension to be divisible by 4 due to - # cutlass blockwise gemm limitation for hopper - num_pad = 4 - (input_2d.shape[0] % 4) - if num_pad > 0: - input_2d = torch.nn.functional.pad(input_2d, - (0, 0, 0, num_pad), - "constant", 0) - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=True) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - if num_pad > 0: - output = output[:-num_pad] - else: - if use_aiter_and_is_supported: - q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) - else: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=False) - - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) - - -def apply_w8a8_block_fp8_linear_fake( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - output_shape = [*input.shape[:-1], weight.shape[0]] - return torch.empty(output_shape, dtype=input.dtype, device=input.device) - + self.weight_group_shape, + output_dtype=input_2d.dtype) -if not current_platform.is_cpu(): - direct_register_custom_op( - op_name="apply_w8a8_block_fp8_linear", - op_func=apply_w8a8_block_fp8_linear, - mutates_args=[], - fake_impl=apply_w8a8_block_fp8_linear_fake, - ) + def _run_cutlass( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + if self.is_hopper: + # We pad unconditionally (even if shape is already divisible by 4) + # to support dynamic shape for input_2d.shape[0] in torch.compile + x = torch.nn.functional.pad(input_2d, + (0, 0, 0, -input_2d.shape[0] % 4)) + else: + x = input_2d + + q_input, x_scale = self.input_quant_op(x) + output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + list(self.weight_group_shape), + input_2d.dtype, self.is_hopper) + output = output[0:input_2d.shape[0], ...] 
+ return output + + def _run_aiter( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.act_quant_group_shape == GroupShape(1, 128) + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + q_input, weight, x_scale, weight_scale, self.weight_group_shape, + input_2d.dtype) + + def _run_triton( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + q_input, x_scale = self.input_quant_op(input_2d) + return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( + q_input, weight, x_scale, weight_scale, self.weight_group_shape, + input_2d.dtype) + + def _dispatch_w8a8_blockscale_op( + self, + use_cutlass: bool, + use_aiter_and_is_supported: bool, + ) -> tuple[Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], torch.Tensor], Optional[QuantFP8]]: + if use_cutlass: + return self._run_cutlass, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=False)) + if use_aiter_and_is_supported: + return self._run_aiter, None + return self._run_triton, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False)) def input_to_float8( @@ -465,7 +538,7 @@ def per_token_group_quant_fp8( @triton.jit -def _w8a8_block_fp8_matmul( +def _w8a8_triton_block_scaled_mm( # Pointers to inputs and output A, B, @@ -590,7 +663,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, return None -def w8a8_block_fp8_matmul( +def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -650,7 +723,7 @@ def grid(META): return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - _w8a8_block_fp8_matmul[grid]( + _w8a8_triton_block_scaled_mm[grid]( A, B, C, @@ -997,25 +1070,6 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module, layer.weight_scale.data.T.contiguous(), requires_grad=False) -def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, - bias: Optional[torch.Tensor], - cutlass_block_fp8_supported: bool, - use_aiter_and_is_supported: bool) -> torch.Tensor: - """Apply block-wise FP8 linear operation.""" - assert layer.weight_block_size is not None - - return torch.ops.vllm.apply_w8a8_block_fp8_linear( - input=input, - weight=layer.weight, - block_size=layer.weight_block_size, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - cutlass_block_fp8_supported=cutlass_block_fp8_supported, - use_aiter_and_is_supported=use_aiter_and_is_supported, - ) - - def expert_weight_is_col_major(x: torch.Tensor) -> bool: assert x.dim() == 3 b, m, n = x.shape diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 4083193d7650..2f533ca0639f 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -9,7 +9,7 @@ import functools import importlib import os -from typing import Any, Callable, NoReturn +from typing import Any, Callable, NoReturn, Optional import torch @@ -172,9 +172,13 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim -def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, - weight: torch.Tensor): - return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 +def should_use_deepgemm_for_fp8_linear( + output_dtype: torch.dtype, + weight: torch.Tensor, + supports_deep_gemm: Optional[bool] = None): + if 
supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return (supports_deep_gemm and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
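
For reference, the renamed w8a8_triton_block_scaled_mm computes a GEMM in which the FP8 activations carry one scale per 1x128 group and the FP8 weights carry one scale per (block_n x block_k) block, as exercised by native_w8a8_block_matmul in the tests. The sketch below is a plain PyTorch reference of that contract, not code from this patch; the function name and the exact scale layouts shown in the comments are illustrative assumptions.

import torch

def ref_block_scaled_mm(A_q: torch.Tensor, B_q: torch.Tensor,
                        As: torch.Tensor, Bs: torch.Tensor,
                        block_size: list[int],
                        out_dtype: torch.dtype) -> torch.Tensor:
    # A_q: (M, K) fp8 activations, As: (M, ceil(K / block_k)) group scales
    # B_q: (N, K) fp8 weights, Bs: (ceil(N / block_n), ceil(K / block_k)) block scales
    block_n, block_k = block_size
    A = A_q.to(torch.float32)
    B = B_q.to(torch.float32)
    M, K = A.shape
    N = B.shape[0]
    C = torch.zeros(M, N, dtype=torch.float32, device=A.device)
    for j in range(0, N, block_n):
        for k in range(0, K, block_k):
            # Dequantize one activation group and one weight block, then accumulate.
            a_blk = A[:, k:k + block_k] * As[:, k // block_k].unsqueeze(1)
            b_blk = B[j:j + block_n, k:k + block_k] * Bs[j // block_n, k // block_k]
            C[:, j:j + block_n] += a_blk @ b_blk.t()
    return C.to(out_dtype)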
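
The new use_ue8m0 flag in QuantFP8 rounds each dynamic group scale up to the nearest power of two before quantizing, mirroring the torch.exp2(torch.ceil(torch.log2(...))) change in _quantize_group_native. A minimal standalone sketch of that rounding follows; the helper name and the 1e-10 lower bound (standing in for _FP8_MIN_SCALING_FACTOR) are illustrative assumptions.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
MIN_SCALE = 1e-10  # illustrative stand-in for _FP8_MIN_SCALING_FACTOR

def group_scales(x: torch.Tensor, group_size: int, use_ue8m0: bool) -> torch.Tensor:
    # x: (..., H) with H divisible by group_size; one scale per group.
    xg = x.view(*x.shape[:-1], -1, group_size)
    absmax = xg.abs().max(dim=-1, keepdim=True)[0].float()
    scales = absmax / FP8_MAX
    if use_ue8m0:
        # Round each scale up to a power of two so the exponent alone
        # (UE8M0) represents it exactly.
        scales = torch.exp2(torch.ceil(torch.log2(scales)))
    return scales.clamp(min=MIN_SCALE)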
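
The forward_native change for column_major_scales keeps the logical scale shape (num_tokens, num_groups) while switching the memory layout to column-major, which is exactly what the updated test asserts (stride(0) == 1, stride(1) == batch_size). A small illustrative sketch of the transpose-contiguous-transpose trick, not taken from the patch:

import torch

def to_column_major(scales: torch.Tensor) -> torch.Tensor:
    # transpose -> contiguous materializes a (num_groups, num_tokens) buffer;
    # transposing back restores the logical shape on top of that buffer,
    # so the result is column-major.
    return scales.transpose(-2, -1).contiguous().transpose(-1, -2)

scales = torch.rand(8, 4)           # (num_tokens=8, num_groups=4), row-major
scales_cm = to_column_major(scales)
assert scales_cm.shape == (8, 4)
assert scales_cm.stride() == (1, 8)
assert torch.allclose(scales_cm, scales)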
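
On Hopper, the CUTLASS path in W8A8BlockFp8LinearOp now pads the token dimension to a multiple of 4 unconditionally (a zero-length pad when already aligned) so a single traced code path covers dynamic shapes under torch.compile, then slices the padded rows back off; the multiple-of-4 requirement comes from the SM90 blockwise CUTLASS kernel. The snippet below sketches only that pad-and-slice pattern with a dense matmul as a stand-in for the scaled GEMM; it is not the kernel call itself.

import torch
import torch.nn.functional as F

def pad_tokens_to_multiple_of_4(x_2d: torch.Tensor) -> torch.Tensor:
    # -M % 4 is 0 when M is already a multiple of 4, so the pad is a no-op
    # there, but the same path is traced for every batch size.
    return F.pad(x_2d, (0, 0, 0, -x_2d.shape[0] % 4))

x = torch.randn(6, 256)
xp = pad_tokens_to_multiple_of_4(x)
assert xp.shape[0] == 8
out_padded = xp @ torch.randn(256, 512)   # stand-in for the scaled GEMM
out = out_padded[:x.shape[0]]             # drop the padded rows again
assert out.shape == (6, 512)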