diff --git a/vllm/envs.py b/vllm/envs.py index b8af770d05f6..832d031f998e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -106,6 +106,8 @@ VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True + VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False + VLLM_ROCM_USE_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True @@ -934,6 +936,18 @@ def get_vllm_port() -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1")), + # Whether to use aiter fp4 gemm asm. + # By default is disabled. + "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in + ("true", "1")), + + # Whether to use aiter rope. + # By default is disabled. + "VLLM_ROCM_USE_TRITON_ROPE": + lambda: (os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in + ("true", "1")), + # Whether to use aiter triton fp8 bmm kernel # By default is enabled. "VLLM_ROCM_USE_AITER_FP8BMM": @@ -1539,6 +1553,8 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_RMSNORM", "VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MHA", + "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", + "VLLM_ROCM_USE_TRITON_ROPE", "VLLM_ROCM_USE_AITER_FP8BMM", "VLLM_ROCM_USE_SKINNY_GEMM", "VLLM_ROCM_FP8_PADDING", diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index df5bced6b228..04a5db07e95c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -323,6 +323,12 @@ def __init__( return_bias: bool = True, disable_tp: bool = False, ): + # If MergedReplicatedLinear, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = self.output_sizes + else: + self.output_partition_sizes = [output_size] + super().__init__(input_size, output_size, skip_bias_add, @@ -335,7 +341,8 @@ def __init__( # All the linear layer supports quant method. assert self.quant_method is not None self.quant_method.create_weights(self, - self.input_size, [self.output_size], + self.input_size, + self.output_partition_sizes, self.input_size, self.output_size, self.params_dtype, @@ -374,12 +381,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.data.copy_(loaded_weight) def forward( - self, x: torch.Tensor + self, + x: torch.Tensor, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: return output return output, output_bias @@ -413,7 +423,7 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. disable_tp: If true, weights matrix won't be sharded through tp rank. """ @@ -535,13 +545,15 @@ def weight_loader_v2(self, param: BasevLLMParameter, param.load_column_parallel_weight(loaded_weight=loaded_weight) def forward( - self, input_ + self, + input_, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None # Matrix multiply. 
assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output and self.tp_size > 1: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -1326,7 +1338,8 @@ def weight_loader_v2(self, param: BasevLLMParameter, param.load_row_parallel_weight(loaded_weight=loaded_weight) def forward( - self, input_ + self, + input_, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if self.input_is_parallel: input_parallel = input_ @@ -1340,9 +1353,8 @@ def forward( # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) + output_parallel = self.quant_method.apply(self, input_parallel, bias_) + if self.reduce_results and self.tp_size > 1: output = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index b67ee5cf453d..c65212c01819 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -395,6 +395,7 @@ def apply(self, scheme = layer.scheme if scheme is None: raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index 880438a22a69..f8628a82277b 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from functools import cache from typing import Any, Callable, Optional import torch import torch.nn.functional as F -from vllm.logger import init_logger +from vllm import envs from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4) @@ -14,7 +15,90 @@ PackedvLLMParameter) from vllm.platforms import current_platform -logger = init_logger(__name__) + +@cache +def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \ + and envs.VLLM_ROCM_USE_AITER + + +try: + from aiter.ops.shuffle import shuffle_weight + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + from vllm.utils import direct_register_custom_op + if is_rocm_aiter_fp4_asm_gemm_enabled(): + from aiter import gemm_a4w4, per_1x32_f4_quant_hip + + def gemm_with_dynamic_quant( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + rocm_use_aiter_fp4_asm_gemm: bool = False, + out_dtype: Optional[torch.dtype] = torch.bfloat16, + x_scales: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + M = x.shape[0] + if rocm_use_aiter_fp4_asm_gemm: + if x_scales is None: + # use hip quant kernel for performance + x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True) + else: + x_q = x + x_s = x_scales + + # 32 alignment is enough for dim0 padding of output for + # gemm_a4w4 kernel + y = 
torch.empty((M + 31) // 32 * 32, + weight.shape[0], + device=x_q.device, + dtype=out_dtype) + + gemm_a4w4(x_q, + weight, + x_s, + weight_scale.view(x_s.dtype), + y, + bpreshuffle=True) + return y[:M] + else: + if x_scales is None: + x_q, x_s = dynamic_mxfp4_quant(x) + else: + x_q = x + x_s = x_scales + y = torch.empty(x_q.shape[0], + weight.shape[0], + device=x_q.device, + dtype=out_dtype) + + gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) + return y + + def gemm_with_dynamic_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + x_scales: torch.Tensor = None, + rocm_use_aiter_fp4_asm_gemm: bool = False, + out_dtype: Optional[torch.dtype] = torch.bfloat16, + ) -> torch.Tensor: + return torch.empty((*x.shape[:-1], weight.shape[0]), + dtype=out_dtype, + device=x.device) + + direct_register_custom_op( + op_name="gemm_with_dynamic_quant", + op_func=gemm_with_dynamic_quant, + mutates_args=[], + fake_impl=gemm_with_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + +except ImportError: + dynamic_mxfp4_quant = gemm_afp4wfp4 = None __all__ = ["QuarkW4A4MXFP4"] @@ -27,29 +111,15 @@ def __init__(self, weight_quant_spec: dict[str, Any], self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec self.input_quant_spec = input_quant_spec - - self.static_input_scales = not input_quant_spec.get("is_dynamic") - - if self.static_input_scales: + self.emulate = not current_platform.supports_mx() + self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled() + if not self.emulate and (dynamic_mxfp4_quant is None + or gemm_afp4wfp4 is None): + # Currently need these kernels if not emulating raise NotImplementedError( - "QuarkW4A4MXFP4 with static input scales is currently not " - "implemented. Please open an issue.") - - if not current_platform.supports_mx(): - self.emulate = True - logger.warning_once( - "The current platform does not support native MXFP4 " - "computation. Simulated weight dequantization and activation " - "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") - else: - self.emulate = True - logger.warning_once( - "The current platform supports native MXFP4 " - "computation, but kernels are not yet integrated in vLLM. " - "Simulated weight dequantization and activation " - "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") + f"{self.__class__.__name__} requires AITER to be installed " + "for non-emulation mode! Please refer to " + "https://github.com/ROCm/aiter for installation details.") @classmethod def get_min_capability(cls) -> int: @@ -58,8 +128,65 @@ def get_min_capability(cls) -> int: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) + + if self.emulate: + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + try: + from quark.torch.export.nn.modules import realquantizer + from quark.torch.quantization.config.config import ( + QuantizationSpec) + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use AMD Quark " + "MX-FP4 models. 
Please install it with `pip install " + "amd-quark`.") from err + + weight_quant_spec = QuantizationSpec.from_dict( + self.weight_quant_spec) + + weight_quantizer = realquantizer.get_real_quantizer( + qspec=weight_quant_spec, + quantizer=None, + real_quantized=True, + reorder=False, + float_dtype=self.out_dtype, + scale_shape=layer.weight_scale.shape, + zero_point_shape=None, + ) + weight_quantizer.scale.data = layer.weight_scale.data + + layer.weight = torch.nn.Parameter( + weight_quantizer(layer.weight.data).to(self.out_dtype), + requires_grad=False, + ) + layer.weight_scale = None + + # This call is necessary to release the scales memory. + torch.cuda.empty_cache() + else: + if self.rocm_use_aiter_fp4_asm_gemm: + # shuffle weight scale + weight_scale_shuffle = layer.weight_scale.data + sm, sn = weight_scale_shuffle.shape + weight_scale_shuffle = weight_scale_shuffle.view( + sm // 32, 2, 16, sn // 8, 2, 4, 1) + weight_scale_shuffle = weight_scale_shuffle.permute( + 0, 3, 5, 2, 4, 1, 6).contiguous() + weight_scale_shuffle = weight_scale_shuffle.view(sm, sn) + layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle, + requires_grad=False) + + # shuffle weight + weight_shuffle = layer.weight.data + weight_shuffle = shuffle_weight(weight_shuffle, + layout=(16, 16)) + layer.weight = torch.nn.Parameter(weight_shuffle, + requires_grad=False) + else: + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data.T.contiguous(), + requires_grad=False) def create_weights(self, layer: torch.nn.Module, output_partition_sizes: list[int], @@ -104,9 +231,9 @@ def apply_weights(self, if self.emulate: dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype) - x = quant_dequant_mxfp4(x) - return F.linear(x, dq_w, bias) else: - raise NotImplementedError() + return torch.ops.vllm.gemm_with_dynamic_quant( + x, layer.weight, layer.weight_scale, + self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 1c3576bee539..0cf634f82a8a 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -8,6 +8,8 @@ from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch +from .rocm_aiter_rope_ops import (is_rocm_triton_rotary_embedding_enabled, + rocm_aiter_rotary_emb) @CustomOp.register("rotary_embedding") @@ -45,6 +47,8 @@ def __init__( cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) + self.is_rocm_triton_rotary_embedding_enabled = \ + is_rocm_triton_rotary_embedding_enabled() def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" @@ -120,14 +124,31 @@ def forward_cuda( return query, key from vllm import _custom_ops as ops - self._match_cos_sin_cache_dtype(query) + # ops.rotary_embedding() is an in-place operation # that updates the query and key tensors. 
ops.rotary_embedding(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style) return query, key + def forward_hip( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.is_rocm_triton_rotary_embedding_enabled: + self._match_cos_sin_cache_dtype(query) + rocm_aiter_rotary_emb(positions, query, key, self.cos_sin_cache, + self.head_size, self.rotary_dim, + self.is_neox_style) + else: + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + self.forward_cuda(positions, query, key) + return query, key + def forward_xpu( self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py new file mode 100644 index 000000000000..da7c84cb442d --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +def is_rocm_triton_rotary_embedding_enabled() -> bool: + return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_TRITON_ROPE) + + +def rocm_aiter_rotary_emb_with_key_forward_triton_impl( + positions: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + rotate_style: int = 0, + is_nope_first: bool = False, +) -> None: + import aiter.ops.triton.rope as ops + ops.rope_cached_thd_positions_2c_fwd_inplace( + query, + key, + cos, + sin, + positions, + rotate_style, + reuse_freqs_front_part=True, + nope_first=is_nope_first, + ) + + +def rocm_aiter_rotary_emb_with_key_forward_triton_fake( + positions: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + rotate_style: int = 0, + is_nope_first: bool = False, +) -> None: + pass + + +if is_rocm_triton_rotary_embedding_enabled(): + + direct_register_custom_op( + op_name="rocm_aiter_rotary_emb_with_key_forward_triton", + op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl, + mutates_args=["key", "query"], + fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def rocm_aiter_rotary_emb(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, cos_sin_cache: torch.Tensor, + head_size: int, rotary_dim: int, + is_neox_style: bool): + num_tokens = positions.numel() + cos, sin = cos_sin_cache.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + rotate_style = 0 if is_neox_style else 1 + + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + query_ = query[..., :rotary_dim] + key_ = key[..., :rotary_dim] + positions = positions.view(*query.shape[:1]) + torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton( + positions, + sin, + cos, + query_, + key_, + rotate_style, + False, + ) + query = query.view(query_shape) + key = key.view(key_shape)
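
To make the e8m0 scale pre-shuffle in QuarkW4A4MXFP4.process_weights_after_loading easier to follow, here is a standalone sketch of the same view/permute on a dummy tensor. It is illustrative only and not part of the patch; it assumes the scale dimensions are multiples of 32 and 8, as the gemm_a4w4 asm kernel expects.

import torch

def shuffle_scales_for_asm(weight_scale: torch.Tensor) -> torch.Tensor:
    # Mirrors the layout transform applied to weight_scale before it is
    # handed to the gemm_a4w4 asm kernel (see the non-emulate branch above).
    sm, sn = weight_scale.shape                      # sm % 32 == 0, sn % 8 == 0 assumed
    s = weight_scale.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
    s = s.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
    return s.view(sm, sn)                            # same shape, kernel-preferred element order

dummy = torch.randint(0, 255, (64, 16), dtype=torch.uint8)   # stand-in for e8m0 scales
assert shuffle_scales_for_asm(dummy).shape == dummy.shape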
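Enabling the new paths at runtime is a matter of flipping the two added flags together with the existing AITER switch. A minimal sketch follows (not part of the patch), assuming a ROCm build with AITER installed; since quark_w4a4_mxfp4 caches the asm-gemm check, the variables should be set before the vLLM modules are imported.

import os

# Both new flags default to off; VLLM_ROCM_USE_AITER must also be on for either to take effect.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_FP4_ASM_GEMM"] = "1"   # asm gemm_a4w4 path in QuarkW4A4MXFP4
os.environ["VLLM_ROCM_USE_TRITON_ROPE"] = "1"          # aiter triton rope path via forward_hip

import vllm.envs as envs
# envs resolves these lazily, so the values above are what the gated code paths will see.
print(envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM, envs.VLLM_ROCM_USE_TRITON_ROPE)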