diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 3d1269c0ecea..544a72052442 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -29,6 +29,10 @@ logger = init_logger(__name__)
 
 USE_XFORMERS_OPS = None
 
+try:
+    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
+except AttributeError:
+    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
 
 
 def check_xformers_availability():
@@ -577,7 +581,7 @@ def unified_attention_fake(
     mutates_args=[],
     fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
-    tags=(torch._C.Tag.cudagraph_unsafe, ),
+    tags=tag_cudagraph_unsafe,
 )
 
 
@@ -628,5 +632,5 @@ def unified_attention_with_output_fake(
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
-    tags=(torch._C.Tag.cudagraph_unsafe, ),
+    tags=tag_cudagraph_unsafe,
 )
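
For reference, a minimal standalone sketch of the guard pattern this diff introduces: `torch._C.Tag.cudagraph_unsafe` only exists in recent PyTorch builds, so the tag tuple is probed once at import time and falls back to an empty tuple on older versions. Only `torch` is assumed here; the `mylib::noop` op and its schema are made up for illustration.

```python
import torch

# Probe for the optional tag once; older torch raises AttributeError.
try:
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]

# The tuple can then be forwarded unconditionally wherever a `tags`
# sequence is accepted, e.g. torch.library.define; on older torch it
# is simply empty and no tag is applied.
torch.library.define("mylib::noop", "(Tensor x) -> Tensor",
                     tags=tag_cudagraph_unsafe)
```

Probing the attribute once at module scope keeps both `direct_register_custom_op` call sites unchanged; only the tuple's contents vary with the installed torch version.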