diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 544a72052442..8db1cc769ba2 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -578,9 +578,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=[],
     fake_impl=unified_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
@@ -631,6 +629,5 @@ def unified_attention_with_output_fake(
     op_func=unified_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 331cd8a87392..04b76a9c2d22 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -547,7 +547,6 @@ def call_trtllm_fused_allreduce_norm_fake(
         "scale_out",
     ],
     fake_impl=call_trtllm_fused_allreduce_norm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 flashinfer_trtllm_fused_allreduce_norm = (
     torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 75de85e1b0ab..76fe9a93259f 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -46,7 +46,6 @@ def all_reduce_symmetric_with_copy_fake(
 direct_register_custom_op(
     op_name="all_reduce_symmetric_with_copy",
     op_func=all_reduce_symmetric_with_copy_impl,
-    mutates_args=[],
     fake_impl=all_reduce_symmetric_with_copy_fake,
 )
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 895971893a66..69f98eb54f36 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
 
 if supports_custom_op():
-    from vllm.platforms import current_platform
 
     direct_register_custom_op(
         op_name="all_reduce",
         op_func=all_reduce,
-        mutates_args=[],
         fake_impl=all_reduce_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
 
     direct_register_custom_op(
         op_name="reduce_scatter",
         op_func=reduce_scatter,
-        mutates_args=[],
         fake_impl=reduce_scatter_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
 
     direct_register_custom_op(
         op_name="all_gather",
         op_func=all_gather,
-        mutates_args=[],
         fake_impl=all_gather_fake,
-        dispatch_key=current_platform.dispatch_key,
     )
diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py
index b1ab84e08ba7..467cbaa8af48 100644
--- a/vllm/lora/ops/triton_ops/lora_expand_op.py
+++ b/vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -11,7 +11,6 @@
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
@@ -283,7 +282,6 @@ def _lora_expand_fake(
     op_func=_lora_expand,
     mutates_args=["output_tensor"],
     fake_impl=_lora_expand_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 lora_expand = torch.ops.vllm.lora_expand
diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index 1e7075ab0715..57da93c226d2 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -11,7 +11,6 @@
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
@@ -237,7 +236,6 @@ def _lora_shrink_fake(
     op_func=_lora_shrink,
     mutates_args=["output_tensor"],
     fake_impl=_lora_shrink_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 lora_shrink = torch.ops.vllm.lora_shrink
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index e358143fac7c..fe586a22e250 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
 direct_register_custom_op(
     op_name="flashinfer_fused_moe_blockscale_fp8",
     op_func=flashinfer_fused_moe_blockscale_fp8,
-    mutates_args=[],
     fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 1e3ac6cd79f6..eb12a9b0a233 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
 direct_register_custom_op(
     op_name="fused_marlin_moe",
     op_func=fused_marlin_moe,
-    mutates_args=[],
     fake_impl=fused_marlin_moe_fake,
 )
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 0e334fdf2404..611df357265b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
 direct_register_custom_op(
     op_name="outplace_fused_experts",
     op_func=outplace_fused_experts,
-    mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
     tags=(() if is_torch_equal_or_newer("2.7.0") else
           (torch.Tag.needs_fixed_stride_order, )),
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 71cc2bcf174d..2bf3bf96baf1 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -2040,7 +2040,6 @@ def moe_forward_fake(
     op_func=moe_forward,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
 
@@ -2071,7 +2070,6 @@ def moe_forward_shared_fake(
     op_func=moe_forward_shared,
     mutates_args=["hidden_states"],
     fake_impl=moe_forward_shared_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index f4972ff5f9cb..2764af5fc532 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -223,17 +223,13 @@ def rocm_aiter_fused_moe_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_asm_moe_tkw1",
     op_func=rocm_aiter_asm_moe_tkw1_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_asm_moe_tkw1_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
     op_name="rocm_aiter_fused_moe",
     op_func=rocm_aiter_fused_moe_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_fused_moe_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
@@ -241,7 +237,6 @@ def rocm_aiter_fused_moe_fake(
     op_func=rocm_aiter_topk_softmax_impl,
     mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
     fake_impl=rocm_aiter_topk_softmax_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
@@ -249,7 +244,6 @@ def rocm_aiter_fused_moe_fake(
     op_func=rocm_aiter_biased_grouped_topk_impl,
     mutates_args=["topk_weights", "topk_ids"],
     fake_impl=rocm_aiter_biased_grouped_topk_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
@@ -257,7 +251,6 @@ def rocm_aiter_fused_moe_fake(
     op_func=rocm_aiter_grouped_topk_impl,
     mutates_args=["topk_weights", "topk_ids"],
     fake_impl=rocm_aiter_grouped_topk_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index f875f712ba9c..8123259d037b 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -103,17 +103,13 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_rms_norm",
     op_func=rocm_aiter_rms_norm_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_rms_norm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 
 direct_register_custom_op(
     op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
     op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index 6a901b47b8b6..410cbef4f6bc 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -31,7 +31,6 @@
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 
@@ -401,5 +400,4 @@ def linear_attention_fake(
     op_func=linear_attention,
     mutates_args=["output"],
     fake_impl=linear_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index a56ee13a6380..d64854cdb381 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -27,7 +27,6 @@
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
 
@@ -464,5 +463,4 @@ def mamba_mixer_fake(
     op_func=mamba_mixer,
     mutates_args=["output"],
     fake_impl=mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 047ce4c4c43d..908ea6e0025f 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -34,7 +34,6 @@
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
 
@@ -765,5 +764,4 @@ def mamba_mixer2_fake(
     op_func=mamba_mixer2,
     mutates_args=["output"],
     fake_impl=mamba_mixer2_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index ffdcd702aab4..cc424760e229 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -21,7 +21,6 @@
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
-from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.short_conv_attn import (
     ShortConvAttentionMetadata)
@@ -251,5 +250,4 @@ def short_conv_fake(
     op_func=short_conv,
     mutates_args=["output"],
     fake_impl=short_conv_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py
index c2b3ccf19fca..8452f686b3ac 100644
--- a/vllm/model_executor/layers/quantization/deepgemm.py
+++ b/vllm/model_executor/layers/quantization/deepgemm.py
@@ -4,7 +4,6 @@
 
 import torch
 
-from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import fp8_gemm_nt
@@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
 direct_register_custom_op(
     op_name="w8a8_deepgemm_block_scaled_mm",
     op_func=w8a8_deepgemm_block_scaled_mm,
-    mutates_args=[],
     fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index a631dfdab654..de25ee84d081 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -161,7 +161,6 @@ def _fused_mul_mat_gguf_fake(
 direct_register_custom_op(
     op_name="_fused_mul_mat_gguf",
     op_func=_fused_mul_mat_gguf,
-    mutates_args=[],
     fake_impl=_fused_mul_mat_gguf_fake,
 )
 fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
@@ -273,7 +272,6 @@ def _fused_moe_gguf_fake(
 direct_register_custom_op(
     op_name="_fused_moe_gguf",
     op_func=_fused_moe_gguf,
-    mutates_args=[],
     fake_impl=_fused_moe_gguf_fake,
 )
 fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
@@ -319,7 +317,6 @@ def _apply_gguf_embedding_fake(
 direct_register_custom_op(
     op_name="_apply_gguf_embedding",
     op_func=_apply_gguf_embedding,
-    mutates_args=[],
     fake_impl=_apply_gguf_embedding_fake,
 )
 apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
index 7f808fa92a9a..e8e950a4bb7b 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -51,9 +51,7 @@ def rocm_aiter_gemm_w8a8_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_gemm_w8a8",
     op_func=rocm_aiter_gemm_w8a8_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_gemm_w8a8_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 2098086bf240..0bc69fe7f930 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -91,9 +91,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
 direct_register_custom_op(
     op_name="rocm_aiter_gemm_w8a8_blockscale",
     op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
-    mutates_args=[],
     fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
 if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
         and current_platform.is_fp8_fnuz()):
@@ -135,7 +133,6 @@ def _w8a8_triton_block_scaled_mm_fake(
 direct_register_custom_op(
     "w8a8_triton_block_scaled_mm_func",
     _w8a8_triton_block_scaled_mm_func,
-    mutates_args=[],
     fake_impl=_w8a8_triton_block_scaled_mm_fake,
     dispatch_key="CUDA",
 )
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 3de928fea720..d61ca7ad5dc4 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -113,7 +113,6 @@ def _quant_dequant_mxfp4_fake(x: torch.Tensor,
 direct_register_custom_op(
     op_name="dequant_mxfp4",
     op_func=_dequant_mxfp4,
-    mutates_args=[],
     fake_impl=_dequant_mxfp4_fake,
 )
 dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
@@ -124,7 +123,6 @@ def _quant_dequant_mxfp4_fake(x: torch.Tensor,
 direct_register_custom_op(
     op_name="quant_dequant_mxfp4",
     op_func=_quant_dequant_mxfp4,
-    mutates_args=[],
     fake_impl=_quant_dequant_mxfp4_fake,
 )
 quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 6ed482db4700..b434b7acfea8 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
 direct_register_custom_op(
     op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
     op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
-    mutates_args=[],
     fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index e3cd0a8e788e..861965106774 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -147,5 +147,4 @@ def _flashinfer_rotary_embedding_fake(
     op_func=_flashinfer_rotary_embedding,
     mutates_args=["query", "key"],  # These tensors are modified in-place
     fake_impl=_flashinfer_rotary_embedding_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index d7a65d43c210..96dd58c0e4d2 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
 direct_register_custom_op(
     op_name="rocm_unquantized_gemm_impl",
     op_func=rocm_unquantized_gemm_impl,
-    mutates_args=[],
     fake_impl=rocm_unquantized_gemm_impl_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 415d36c681d8..9895ebbcdefe 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -56,7 +56,6 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv, direct_register_custom_op
 
@@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
 direct_register_custom_op(
     op_name="sequence_parallel_chunk",
     op_func=sequence_parallel_chunk,
-    mutates_args=[],
     fake_impl=sequence_parallel_chunk_fake,
-    dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 0292f3bf8317..a7acf64f302b 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -48,7 +48,6 @@
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -490,7 +489,6 @@ def plamo2_mamba_mixer_fake(
     op_func=plamo2_mamba_mixer,
     mutates_args=["output"],
     fake_impl=plamo2_mamba_mixer_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index ab23b494e561..356b5001a7dc 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -1225,7 +1225,6 @@ def gdn_attention_fake(
     op_func=gdn_attention,
     mutates_args=["output"],
     fake_impl=gdn_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
 )
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 5d165f166238..0a7af79f7a17 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -2546,10 +2546,10 @@ def __getattr__(self, key: str):
 def direct_register_custom_op(
     op_name: str,
     op_func: Callable,
-    mutates_args: list[str],
+    mutates_args: Optional[list[str]] = None,
     fake_impl: Optional[Callable] = None,
     target_lib: Optional[Library] = None,
-    dispatch_key: str = "CUDA",
+    dispatch_key: Optional[str] = None,
     tags: tuple[torch.Tag, ...] = (),
 ):
     """
@@ -2577,6 +2577,13 @@ def direct_register_custom_op(
                 "the required dependencies.")
         return
 
+    if mutates_args is None:
+        mutates_args = []
+
+    if dispatch_key is None:
+        from vllm.platforms import current_platform
+        dispatch_key = current_platform.dispatch_key
+
     import torch.library
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func,
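
For context, here is a minimal sketch of how a call site can look after this change, using a hypothetical `my_identity` op that is not part of this diff: `mutates_args` and `dispatch_key` may now be omitted for a non-mutating op, with `mutates_args` defaulting to `[]` and `dispatch_key` resolving to `current_platform.dispatch_key` inside `direct_register_custom_op`.

```python
# Illustrative example only (hypothetical op, not part of this patch),
# assuming vLLM is installed and the op lands in the default torch.ops.vllm
# namespace used by direct_register_custom_op.
import torch

from vllm.utils import direct_register_custom_op


def my_identity(x: torch.Tensor) -> torch.Tensor:
    # Real implementation of the op.
    return x.clone()


def my_identity_fake(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing and torch.compile.
    return torch.empty_like(x)


# Previously every call site had to pass mutates_args=[] and usually
# dispatch_key=current_platform.dispatch_key; with the new defaults both
# can be dropped for a non-mutating op on the current platform.
direct_register_custom_op(
    op_name="my_identity",
    op_func=my_identity,
    fake_impl=my_identity_fake,
)

out = torch.ops.vllm.my_identity(torch.randn(4))
```

Ops that do mutate their arguments (for example the `mutates_args=["output"]` registrations above) still pass the list explicitly, and ops pinned to a specific backend can still pass `dispatch_key` by hand, as `w8a8_triton_block_scaled_mm_func` does with `dispatch_key="CUDA"`.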