diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index baedafbae99f..e69888b737e9 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -8,7 +8,7 @@ import torch from torch.library import Library -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, tag_cudagraph_unsafe # Shared library for all compilation test operations # Using "silly" namespace to match existing test expectations @@ -60,5 +60,5 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mutates_args=["out"], fake_impl=silly_attention_fake, target_lib=silly_lib, - tags=(torch._C.Tag.cudagraph_unsafe, ), + tags=tag_cudagraph_unsafe, ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 544a72052442..7ff064c0b968 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -25,14 +25,11 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import _Backend, current_platform -from vllm.utils import GiB_bytes, direct_register_custom_op +from vllm.utils import (GiB_bytes, direct_register_custom_op, + tag_cudagraph_unsafe) logger = init_logger(__name__) USE_XFORMERS_OPS = None -try: - tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, ) -except AttributeError: - tag_cudagraph_unsafe = () # type: ignore[assignment] def check_xformers_availability(): diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index fd1c0af31269..f225b396b45a 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3472,3 +3472,11 @@ def length_from_prompt_token_ids_or_embeds( f" prompt_token_ids={prompt_token_len}" f" prompt_embeds={prompt_embeds_len}") return prompt_token_len + + +if is_torch_equal_or_newer("2.9.0.dev"): + from vllm.platforms import current_platform + tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, + ) if current_platform.is_cuda_alike() else () +else: + tag_cudagraph_unsafe = () # type: ignore[assignment]