16 changes: 16 additions & 0 deletions vllm/envs.py
@@ -106,6 +106,8 @@
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
@@ -934,6 +936,18 @@ def get_vllm_port() -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),

# Whether to use the aiter FP4 ASM GEMM kernel.
# Disabled by default.
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in
("true", "1")),

# Whether to use the aiter Triton RoPE kernel.
# Disabled by default.
"VLLM_ROCM_USE_TRITON_ROPE":
lambda: (os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in
("true", "1")),

# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM":
@@ -1539,6 +1553,8 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_RMSNORM",
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
28 changes: 20 additions & 8 deletions vllm/model_executor/layers/linear.py
@@ -323,6 +323,12 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]

super().__init__(input_size,
output_size,
skip_bias_add,
@@ -335,7 +341,8 @@ def __init__(
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_partition_sizes,
self.input_size,
self.output_size,
self.params_dtype,
@@ -374,12 +381,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)

def forward(
self, x: torch.Tensor
self,
x: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None

output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None

if not self.return_bias:
return output
return output, output_bias
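
To make the ReplicatedLinear change concrete: a merged subclass that defines output_sizes passes those per-partition sizes to create_weights, while a plain layer falls back to a single-entry list. A simplified, self-contained sketch of that dispatch (class names shortened; these are not the real vLLM classes):

class ReplicatedLinearSketch:
    def __init__(self, input_size: int, output_size: int):
        # MergedReplicatedLinear sets self.output_sizes before calling
        # super().__init__(), so hasattr() distinguishes the two cases.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

class MergedReplicatedLinearSketch(ReplicatedLinearSketch):
    def __init__(self, input_size: int, output_sizes: list[int]):
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes))

layer = MergedReplicatedLinearSketch(1024, [512, 512, 512])
print(layer.output_partition_sizes)  # [512, 512, 512] rather than [1536]

Passing the per-partition sizes lets a quantization method allocate weights and scales per fused projection instead of for the concatenated output.
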
@@ -413,7 +423,7 @@ class ColumnParallelLinear(LinearBase):
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
@@ -535,13 +545,15 @@ def weight_loader_v2(self, param: BasevLLMParameter,
param.load_column_parallel_weight(loaded_weight=loaded_weight)

def forward(
self, input_
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)

if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
@@ -1326,7 +1338,8 @@ def weight_loader_v2(self, param: BasevLLMParameter,
param.load_row_parallel_weight(loaded_weight=loaded_weight)

def forward(
self, input_
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -1340,9 +1353,8 @@ def forward(
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
output_parallel = self.quant_method.apply(self, input_parallel, bias_)

if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
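
The rank-0 bias handling in RowParallelLinear.forward exists because the subsequent all-reduce sums the partial outputs of every TP rank; if each rank fused the bias into its GEMM, the reduced result would contain tp_size copies of it. A small numeric illustration of that reasoning (plain PyTorch, no vLLM APIs):

import torch

tp_size = 4
partials = [torch.ones(2, 3) for _ in range(tp_size)]  # per-rank GEMM outputs
bias = torch.full((3,), 10.0)

# Every rank adds the bias -> the "all-reduce" counts it tp_size times.
wrong = sum(p + bias for p in partials)

# Only rank 0 adds the bias (bias_ is None elsewhere) -> counted once.
right = sum(p + (bias if rank == 0 else 0) for rank, p in enumerate(partials))

print(wrong[0, 0].item(), right[0, 0].item())  # 44.0 vs. 14.0
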
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -395,6 +395,7 @@ def apply(self,
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")

return scheme.apply_weights(layer, x, bias=bias)


@@ -1,20 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cache
from typing import Any, Callable, Optional

import torch
import torch.nn.functional as F

from vllm.logger import init_logger
from vllm import envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform

logger = init_logger(__name__)

@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \
and envs.VLLM_ROCM_USE_AITER


try:
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant

from vllm.utils import direct_register_custom_op
if is_rocm_aiter_fp4_asm_gemm_enabled():
from aiter import gemm_a4w4, per_1x32_f4_quant_hip

def gemm_with_dynamic_quant(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
rocm_use_aiter_fp4_asm_gemm: bool = False,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
x_scales: Optional[torch.Tensor] = None,
) -> torch.Tensor:
M = x.shape[0]
if rocm_use_aiter_fp4_asm_gemm:
if x_scales is None:
# use hip quant kernel for performance
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
else:
x_q = x
x_s = x_scales

# 32 alignment is enough for dim0 padding of output for
# gemm_a4w4 kernel
y = torch.empty((M + 31) // 32 * 32,
weight.shape[0],
device=x_q.device,
dtype=out_dtype)

gemm_a4w4(x_q,
weight,
x_s,
weight_scale.view(x_s.dtype),
y,
bpreshuffle=True)
return y[:M]
else:
if x_scales is None:
x_q, x_s = dynamic_mxfp4_quant(x)
else:
x_q = x
x_s = x_scales
y = torch.empty(x_q.shape[0],
weight.shape[0],
device=x_q.device,
dtype=out_dtype)

gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y

def gemm_with_dynamic_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
x_scales: torch.Tensor = None,
rocm_use_aiter_fp4_asm_gemm: bool = False,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
return torch.empty((*x.shape[:-1], weight.shape[0]),
dtype=out_dtype,
device=x.device)

direct_register_custom_op(
op_name="gemm_with_dynamic_quant",
op_func=gemm_with_dynamic_quant,
mutates_args=[],
fake_impl=gemm_with_dynamic_quant_fake,
dispatch_key=current_platform.dispatch_key,
)

except ImportError:
dynamic_mxfp4_quant = gemm_afp4wfp4 = None
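
Once direct_register_custom_op succeeds, callers go through the torch.ops.vllm namespace (as apply_weights does further down), which keeps the kernel traceable under torch.compile via the registered fake implementation. A hedged caller sketch, assuming a ROCm build with AITER installed so the registration above actually ran:

import torch

def mxfp4_linear(x: torch.Tensor, weight: torch.Tensor,
                 weight_scale: torch.Tensor,
                 use_asm_gemm: bool = False) -> torch.Tensor:
    # x_scales is omitted, so the op quantizes x on the fly
    # (the dynamic-quant path shown above).
    return torch.ops.vllm.gemm_with_dynamic_quant(
        x, weight, weight_scale, use_asm_gemm, torch.bfloat16)
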

__all__ = ["QuarkW4A4MXFP4"]

@@ -27,29 +111,15 @@ def __init__(self, weight_quant_spec: dict[str, Any],
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec

self.static_input_scales = not input_quant_spec.get("is_dynamic")

if self.static_input_scales:
self.emulate = not current_platform.supports_mx()
self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
if not self.emulate and (dynamic_mxfp4_quant is None
or gemm_afp4wfp4 is None):
# Currently need these kernels if not emulating
raise NotImplementedError(
"QuarkW4A4MXFP4 with static input scales is currently not "
"implemented. Please open an issue.")

if not current_platform.supports_mx():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
f"{self.__class__.__name__} requires AITER to be installed "
"for non-emulation mode! Please refer to "
"https://github.com/ROCm/aiter for installation details.")

@classmethod
def get_min_capability(cls) -> int:
@@ -58,8 +128,65 @@ def get_min_capability(cls) -> int:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)

if self.emulate:
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err

weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)

weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data

layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
layer.weight_scale = None

# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
Review comment (Contributor) on lines +132 to +166:
I insist that this is unnecessary https://github.com/vllm-project/vllm/pull/25135/files#r2378191214 - was not able to reopen the thread that was closed unfortunately.
else:
if self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data
sm, sn = weight_scale_shuffle.shape
weight_scale_shuffle = weight_scale_shuffle.view(
sm // 32, 2, 16, sn // 8, 2, 4, 1)
weight_scale_shuffle = weight_scale_shuffle.permute(
0, 3, 5, 2, 4, 1, 6).contiguous()
weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle,
requires_grad=False)

# shuffle weight
weight_shuffle = layer.weight.data
weight_shuffle = shuffle_weight(weight_shuffle,
layout=(16, 16))
layer.weight = torch.nn.Parameter(weight_shuffle,
requires_grad=False)
else:
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data.T.contiguous(),
requires_grad=False)
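
The scale preshuffle in the non-emulation branch is a fixed view/permute into the tile layout that the gemm_a4w4 ASM kernel expects; the same transform can be written as a standalone helper (the helper name and example shape are illustrative; sm must be a multiple of 32 and sn a multiple of 8):

import torch

def shuffle_scales_for_asm_gemm(weight_scale: torch.Tensor) -> torch.Tensor:
    # Identical sequence of ops to process_weights_after_loading above.
    sm, sn = weight_scale.shape
    s = weight_scale.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
    s = s.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
    return s.view(sm, sn)

scales = torch.randn(64, 16)                       # illustrative shape
print(shuffle_scales_for_asm_gemm(scales).shape)   # torch.Size([64, 16])
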

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
@@ -104,9 +231,9 @@ def apply_weights(self,

if self.emulate:
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)

x = quant_dequant_mxfp4(x)

return F.linear(x, dq_w, bias)
else:
raise NotImplementedError()
return torch.ops.vllm.gemm_with_dynamic_quant(
x, layer.weight, layer.weight_scale,
self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
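
For contrast, the emulation branch of apply_weights is a plain QDQ pass: the packed weight is dequantized to the activation dtype, the activation is quantized and immediately dequantized, and the matmul runs in high precision. A minimal sketch of that structure with the quant helpers passed in as parameters (stand-ins for dequant_mxfp4 and quant_dequant_mxfp4 from mxfp4_utils):

import torch.nn.functional as F

def emulated_mxfp4_linear(x, packed_weight, weight_scale, bias,
                          dequant, quant_dequant):
    dq_w = dequant(packed_weight, weight_scale, x.dtype)  # high-precision weight
    x_qdq = quant_dequant(x)                              # simulated activation quant
    return F.linear(x_qdq, dq_w, bias)                    # GEMM runs in x.dtype
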