3 changes: 0 additions & 3 deletions vllm/attention/layer.py
@@ -578,9 +578,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=[],
     fake_impl=unified_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )

@@ -631,6 +629,5 @@ def unified_attention_with_output_fake(
     op_func=unified_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
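Every hunk in this PR removes the same two arguments, mutates_args=[] and dispatch_key=current_platform.dispatch_key, which only works if direct_register_custom_op in vllm.utils now supplies those values itself. Below is a minimal sketch of what the helper presumably looks like after this change; the default-resolution logic and the "demo_ops" library object are assumptions for illustration, not the actual vLLM code.

# Hedged sketch of the assumed new defaults; not vLLM's actual implementation.
from collections.abc import Callable, Sequence
from typing import Optional

import torch
from torch.library import Library

# Assumption: a module-level library namespace standing in for vLLM's own.
demo_lib = Library("demo_ops", "FRAGMENT")

def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Optional[Sequence[str]] = None,  # assumed new default
    fake_impl: Optional[Callable] = None,
    dispatch_key: Optional[str] = None,            # assumed new default
    tags: tuple = (),
) -> None:
    if mutates_args is None:
        mutates_args = []  # matches the `mutates_args=[]` dropped at call sites
    if dispatch_key is None:
        # Assumption: resolved once here, which would explain why the
        # per-file `from vllm.platforms import current_platform` imports
        # also disappear in this PR.
        from vllm.platforms import current_platform
        dispatch_key = current_platform.dispatch_key
    # Infer the schema from the Python signature and register all pieces.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    demo_lib.define(op_name + schema, tags=tags)
    demo_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        demo_lib._register_fake(op_name, fake_impl)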
1 change: 0 additions & 1 deletion vllm/compilation/collective_fusion.py
@@ -547,7 +547,6 @@ def call_trtllm_fused_allreduce_norm_fake(
         "scale_out",
     ],
     fake_impl=call_trtllm_fused_allreduce_norm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 flashinfer_trtllm_fused_allreduce_norm = (
     torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
1 change: 0 additions & 1 deletion vllm/distributed/device_communicators/pynccl.py
@@ -46,7 +46,6 @@ def all_reduce_symmetric_with_copy_fake(
 direct_register_custom_op(
     op_name="all_reduce_symmetric_with_copy",
     op_func=all_reduce_symmetric_with_copy_impl,
-    mutates_args=[],
     fake_impl=all_reduce_symmetric_with_copy_fake,
 )
7 changes: 0 additions & 7 deletions vllm/distributed/parallel_state.py
@@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
 
 
 if supports_custom_op():
-    from vllm.platforms import current_platform
     direct_register_custom_op(
         op_name="all_reduce",
         op_func=all_reduce,
-        mutates_args=[],
         fake_impl=all_reduce_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
 
     direct_register_custom_op(
         op_name="reduce_scatter",
         op_func=reduce_scatter,
-        mutates_args=[],
         fake_impl=reduce_scatter_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
 
     direct_register_custom_op(
         op_name="all_gather",
         op_func=all_gather,
-        mutates_args=[],
         fake_impl=all_gather_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
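The impl/fake pairing that all of these registrations share is untouched by the PR: the real function runs on device, while the fake runs on meta tensors during torch.compile tracing so shapes propagate without launching collectives or kernels. A self-contained illustration of that pattern using plain torch.library; the "demo" namespace and the op itself are invented for the example.

import torch

def scale_impl(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor  # real computation

def scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake/meta implementation: only shape and dtype math, no kernel launch.
    return torch.empty_like(x)

lib = torch.library.Library("demo", "FRAGMENT")  # invented namespace
lib.define("scale(Tensor x, float factor) -> Tensor")
lib.impl("scale", scale_impl, dispatch_key="CompositeExplicitAutograd")
torch.library.register_fake("demo::scale", scale_fake)

out = torch.ops.demo.scale(torch.ones(4), 2.0)  # dispatches to scale_impl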
2 changes: 0 additions & 2 deletions vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -11,7 +11,6 @@
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 

@@ -283,7 +282,6 @@ def _lora_expand_fake(
     op_func=_lora_expand,
     mutates_args=["output_tensor"],
     fake_impl=_lora_expand_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 lora_expand = torch.ops.vllm.lora_expand
2 changes: 0 additions & 2 deletions vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -11,7 +11,6 @@
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 

@@ -237,7 +236,6 @@ def _lora_shrink_fake(
     op_func=_lora_shrink,
     mutates_args=["output_tensor"],
     fake_impl=_lora_shrink_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 lora_shrink = torch.ops.vllm.lora_shrink
@@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
 direct_register_custom_op(
     op_name="flashinfer_fused_moe_blockscale_fp8",
     op_func=flashinfer_fused_moe_blockscale_fp8,
-    mutates_args=[],
     fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
1 change: 0 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
 direct_register_custom_op(
     op_name="fused_marlin_moe",
     op_func=fused_marlin_moe,
-    mutates_args=[],
     fake_impl=fused_marlin_moe_fake,
 )
1 change: 0 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
 direct_register_custom_op(
     op_name="outplace_fused_experts",
     op_func=outplace_fused_experts,
-    mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
     tags=(() if is_torch_equal_or_newer("2.7.0") else
           (torch.Tag.needs_fixed_stride_order, )),
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -2040,7 +2040,6 @@ def moe_forward_fake(
     op_func=moe_forward,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
 

@@ -2071,7 +2070,6 @@ def moe_forward_shared_fake(
     op_func=moe_forward_shared,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_shared_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
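Note what stays: non-empty declarations such as mutates_args=["hidden_states"] above are kept, because the dispatcher must know about in-place writes for aliasing analysis and functionalization to stay correct; only the redundant empty-list default is deleted. A small hedged illustration of why the declaration matters, using torch.library.custom_op with a namespace and op invented for the example.

import torch

# Invented example: the op writes into `out`, so that must be declared.
@torch.library.custom_op("demo_mut::fill_doubled", mutates_args=("out",))
def fill_doubled(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x)
    out.add_(x)  # in-place writes, visible to the caller

@fill_doubled.register_fake
def _(x: torch.Tensor, out: torch.Tensor) -> None:
    # Nothing to compute for tracing; the mutation is declared in the schema.
    pass

x = torch.ones(4)
out = torch.empty(4)
torch.ops.demo_mut.fill_doubled(x, out)
assert torch.equal(out, 2 * x)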
@@ -223,41 +223,34 @@ def rocm_aiter_fused_moe_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_asm_moe_tkw1",
     op_func=rocm_aiter_asm_moe_tkw1_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_asm_moe_tkw1_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
     op_name="rocm_aiter_fused_moe",
     op_func=rocm_aiter_fused_moe_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_fused_moe_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
     op_name="rocm_aiter_topk_softmax",
     op_func=rocm_aiter_topk_softmax_impl,
     mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
     fake_impl=rocm_aiter_topk_softmax_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
     op_name="rocm_aiter_biased_grouped_topk",
     op_func=rocm_aiter_biased_grouped_topk_impl,
     mutates_args=["topk_weights", "topk_ids"],
     fake_impl=rocm_aiter_biased_grouped_topk_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
     op_name="rocm_aiter_grouped_topk",
     op_func=rocm_aiter_grouped_topk_impl,
     mutates_args=["topk_weights", "topk_ids"],
     fake_impl=rocm_aiter_grouped_topk_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
4 changes: 0 additions & 4 deletions vllm/model_executor/layers/layernorm.py
@@ -103,17 +103,13 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_rms_norm",
     op_func=rocm_aiter_rms_norm_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_rms_norm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
     op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
     op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/mamba/linear_attn.py
@@ -31,7 +31,6 @@
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 

@@ -401,5 +400,4 @@ def linear_attention_fake(
     op_func=linear_attention,
     mutates_args=["output"],
     fake_impl=linear_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -27,7 +27,6 @@
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
 

@@ -464,5 +463,4 @@ def mamba_mixer_fake(
     op_func=mamba_mixer,
     mutates_args=["output"],
     fake_impl=mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -34,7 +34,6 @@
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
 

@@ -765,5 +764,4 @@ def mamba_mixer2_fake(
     op_func=mamba_mixer2,
     mutates_args=["output"],
     fake_impl=mamba_mixer2_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/mamba/short_conv.py
@@ -21,7 +21,6 @@
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.short_conv_attn import (
     ShortConvAttentionMetadata)

@@ -251,5 +250,4 @@ def short_conv_fake(
     op_func=short_conv,
     mutates_args=["output"],
     fake_impl=short_conv_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/deepgemm.py
@@ -4,7 +4,6 @@
 
 import torch
 
-from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import fp8_gemm_nt

@@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
 direct_register_custom_op(
     op_name="w8a8_deepgemm_block_scaled_mm",
     op_func=w8a8_deepgemm_block_scaled_mm,
-    mutates_args=[],
     fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -161,7 +161,6 @@ def _fused_mul_mat_gguf_fake(
 direct_register_custom_op(
     op_name="_fused_mul_mat_gguf",
     op_func=_fused_mul_mat_gguf,
-    mutates_args=[],
     fake_impl=_fused_mul_mat_gguf_fake,
 )
 fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf

@@ -273,7 +272,6 @@ def _fused_moe_gguf_fake(
 direct_register_custom_op(
     op_name="_fused_moe_gguf",
     op_func=_fused_moe_gguf,
-    mutates_args=[],
     fake_impl=_fused_moe_gguf_fake,
 )
 fused_moe_gguf = torch.ops.vllm._fused_moe_gguf

@@ -319,7 +317,6 @@ def _apply_gguf_embedding_fake(
 direct_register_custom_op(
     op_name="_apply_gguf_embedding",
     op_func=_apply_gguf_embedding,
-    mutates_args=[],
     fake_impl=_apply_gguf_embedding_fake,
 )
 apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
@@ -51,9 +51,7 @@ def rocm_aiter_gemm_w8a8_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_gemm_w8a8",
     op_func=rocm_aiter_gemm_w8a8_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_gemm_w8a8_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -91,9 +91,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_gemm_w8a8_blockscale",
     op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
         and current_platform.is_fp8_fnuz()):

@@ -135,7 +133,6 @@ def _w8a8_triton_block_scaled_mm_fake(
 direct_register_custom_op(
     "w8a8_triton_block_scaled_mm_func",
     _w8a8_triton_block_scaled_mm_func,
-    mutates_args=[],
     fake_impl=_w8a8_triton_block_scaled_mm_fake,
     dispatch_key="CUDA",
 )
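Also worth noting in this file: the explicit dispatch_key="CUDA" on w8a8_triton_block_scaled_mm_func survives the cleanup. Only the platform-derived keys were redundant; an explicit key still pins a registration to one backend. A hedged sketch of that pinning with a namespace and op invented for the example.

import torch

lib = torch.library.Library("demo_cuda", "FRAGMENT")  # invented namespace
lib.define("double(Tensor x) -> Tensor")

def double_cuda(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # stand-in for a Triton kernel

# Registered under the CUDA dispatch key only; CPU tensors won't reach it.
lib.impl("double", double_cuda, dispatch_key="CUDA")

if torch.cuda.is_available():
    y = torch.ops.demo_cuda.double(torch.ones(3, device="cuda"))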
@@ -113,7 +113,6 @@ def _quant_dequant_mxfp4_fake(x: torch.Tensor,
 direct_register_custom_op(
     op_name="dequant_mxfp4",
     op_func=_dequant_mxfp4,
-    mutates_args=[],
     fake_impl=_dequant_mxfp4_fake,
 )
 dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4

@@ -124,7 +123,6 @@ def _quant_dequant_mxfp4_fake(x: torch.Tensor,
 direct_register_custom_op(
     op_name="quant_dequant_mxfp4",
     op_func=_quant_dequant_mxfp4,
-    mutates_args=[],
     fake_impl=_quant_dequant_mxfp4_fake,
 )
 quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
 direct_register_custom_op(
     op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
     op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
-    mutates_args=[],
     fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
1 change: 0 additions & 1 deletion vllm/model_executor/layers/rotary_embedding/common.py
@@ -147,5 +147,4 @@ def _flashinfer_rotary_embedding_fake(
     op_func=_flashinfer_rotary_embedding,
     mutates_args=["query", "key"],  # These tensors are modified in-place
     fake_impl=_flashinfer_rotary_embedding_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/utils.py
@@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
 direct_register_custom_op(
     op_name="rocm_unquantized_gemm_impl",
     op_func=rocm_unquantized_gemm_impl,
-    mutates_args=[],
     fake_impl=rocm_unquantized_gemm_impl_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
3 changes: 0 additions & 3 deletions vllm/model_executor/models/deepseek_v2.py
@@ -56,7 +56,6 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv, direct_register_custom_op
 

@@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
 direct_register_custom_op(
     op_name="sequence_parallel_chunk",
     op_func=sequence_parallel_chunk,
-    mutates_args=[],
     fake_impl=sequence_parallel_chunk_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
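The tags=(torch.Tag.needs_fixed_stride_order, ) argument is likewise untouched throughout the PR; it tells the compiler to hand the op its inputs in the exact stride order the kernel was written for instead of reordering them. Attaching a tag at definition time looks roughly like the sketch below (invented namespace and op, for illustration only).

import torch

lib = torch.library.Library("demo_tags", "FRAGMENT")  # invented namespace
lib.define(
    "copy_strided(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),
)
lib.impl("copy_strided", lambda x: x.clone(),
         dispatch_key="CompositeExplicitAutograd")

# A transposed (non-contiguous) input exercises the stride-order constraint.
y = torch.ops.demo_tags.copy_strided(torch.randn(2, 3).t())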
2 changes: 0 additions & 2 deletions vllm/model_executor/models/plamo2.py
@@ -48,7 +48,6 @@
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata

@@ -490,7 +489,6 @@ def plamo2_mamba_mixer_fake(
     op_func=plamo2_mamba_mixer,
     mutates_args=["output"],
     fake_impl=plamo2_mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )