From 2a64311ed2ee71b4eb7db92de98c83473699f322 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Thu, 25 Sep 2025 14:00:57 -0700
Subject: [PATCH] remove deepgemm register

Signed-off-by: yewentao256
---
 .../layers/quantization/deepgemm.py         | 78 -------------------
 .../layers/quantization/utils/fp8_utils.py  | 17 ++--
 2 files changed, 5 insertions(+), 90 deletions(-)
 delete mode 100644 vllm/model_executor/layers/quantization/deepgemm.py

diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py
deleted file mode 100644
index 2236824ce910..000000000000
--- a/vllm/model_executor/layers/quantization/deepgemm.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import logging
-
-import torch
-
-from vllm.triton_utils import triton
-from vllm.utils import direct_register_custom_op
-from vllm.utils.deep_gemm import fp8_gemm_nt
-
-logger = logging.getLogger(__name__)
-
-
-def prepare_block_fp8_matmul_inputs(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
-    output_dtype: torch.dtype = torch.float16,
-) -> tuple[int, int, int, torch.Tensor]:
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
-
-    assert A.shape[-1] == B.shape[-1]
-    assert A.shape[:-1] == As.shape[:-1]
-    assert A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
-
-    M = A.numel() // A.shape[-1]
-
-    assert B.ndim == 2
-    assert B.is_contiguous()
-    assert Bs.ndim == 2
-    N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
-
-    C_shape = A.shape[:-1] + (N, )
-    C = A.new_empty(C_shape, dtype=output_dtype)
-
-    return M, N, K, C
-
-
-def w8a8_block_fp8_matmul_deepgemm(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
-    output_dtype: torch.dtype,
-) -> torch.Tensor:
-    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
-                                                 output_dtype)
-    # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16
-    fp8_gemm_nt((A, As), (B, Bs), C)
-    return C
-
-
-def w8a8_block_fp8_matmul_deepgemm_fake(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
-    output_dtype: torch.dtype,
-) -> torch.Tensor:
-    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
-                                                 output_dtype)
-    return C
-
-
-direct_register_custom_op(
-    op_name="w8a8_block_fp8_matmul_deepgemm",
-    op_func=w8a8_block_fp8_matmul_deepgemm,
-    fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
-)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index b32c67dec7ff..b2548e66827d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -23,7 +23,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
-from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
+from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
                                   should_use_deepgemm_for_fp8_linear)

 logger = init_logger(__name__)
@@ -141,17 +141,10 @@ def apply_w8a8_block_fp8_linear(
             block_size[1],
             column_major_scales=True,
         )
-
-        # ensure DeepGEMM-backed custom op is registered before use
-        import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
-
-        output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
-            q_input,
-            weight,
-            x_scale,
-            weight_scale,
-            block_size,
-            output_dtype=output_dtype)
+        output = torch.empty((q_input.shape[0], weight.shape[0]),
+                             dtype=torch.bfloat16,
+                             device=q_input.device)
+        fp8_gemm_nt((q_input, x_scale), (weight, weight_scale), output)
         if bias is not None:
             output += bias
         return output.to(dtype=output_dtype).view(*output_shape)
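
For reference, the replacement code path calls DeepGEMM's fp8_gemm_nt directly
instead of routing through the removed torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm
custom op. Below is a minimal sketch of that call pattern, lifted from the
fp8_utils.py hunk above; it is illustrative only and not part of the patch. The
helper name _deepgemm_block_fp8_linear is hypothetical, and the sketch assumes
DeepGEMM is available and that q_input/weight and their scales were produced by
the surrounding block-fp8 quantization code.

# Illustrative sketch, not part of the patch.
import torch

from vllm.utils.deep_gemm import fp8_gemm_nt


def _deepgemm_block_fp8_linear(q_input, x_scale, weight, weight_scale):
    # DeepGEMM writes into a preallocated output and only supports bfloat16.
    output = torch.empty((q_input.shape[0], weight.shape[0]),
                         dtype=torch.bfloat16,
                         device=q_input.device)
    # (activation, activation_scale) x (weight, weight_scale)^T -> output
    fp8_gemm_nt((q_input, x_scale), (weight, weight_scale), output)
    return output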