diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py
index c3e99e177e2d..b6ecc3392bd3 100644
--- a/vllm/distributed/device_communicators/pynccl_wrapper.py
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -31,6 +31,7 @@
 from torch.distributed import ReduceOp
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import find_nccl_library
 
 logger = init_logger(__name__)
@@ -223,6 +224,9 @@ class NCCLLibrary:
         Function("ncclGroupStart", ncclResult_t, []),
         # ncclResult_t ncclGroupEnd();
         Function("ncclGroupEnd", ncclResult_t, []),
+    ]
+
+    exported_functions_cuda_specific = [
         # ncclResult_t ncclCommWindowRegister(
         #     ncclComm_t comm, void* buff, size_t size,
         #     ncclWindow_t* win, int winFlags);
@@ -271,10 +275,12 @@ def __init__(self, so_file: Optional[str] = None):
                 " to point to the correct nccl library path.", so_file,
                 platform.platform())
             raise e
-
+        function_specs = list(NCCLLibrary.exported_functions)
+        if current_platform.is_cuda():
+            function_specs.extend(NCCLLibrary.exported_functions_cuda_specific)
         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs: dict[str, Any] = {}
-            for func in NCCLLibrary.exported_functions:
+            for func in function_specs:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype
                 f.argtypes = func.argtypes
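
For context on the pattern: the change splits the ctypes function table so that CUDA-only symbols are resolved via `getattr` only when `current_platform.is_cuda()` is true, presumably because the window-registration entry points are not exported by the NCCL-compatible library on other platforms (e.g. RCCL on ROCm), where an unconditional lookup would fail. Below is a minimal standalone sketch of that gating, not vLLM code; the names `base_specs`, `cuda_only_specs`, and `bind` are hypothetical, and the `ncclCommWindowRegister` argument types are transcribed from the C signature quoted in the diff comments.

```python
# Minimal sketch (hypothetical names, not the vLLM implementation) of
# platform-gated ctypes symbol binding.
import ctypes
from dataclasses import dataclass
from typing import Any

ncclResult_t = ctypes.c_int


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: list


# Symbols expected on every supported platform.
base_specs = [
    Function("ncclGroupStart", ncclResult_t, []),
    Function("ncclGroupEnd", ncclResult_t, []),
]

# Symbols gated to CUDA builds; types follow the C signature in the
# diff comments (comm, buff, size, win*, winFlags).
cuda_only_specs = [
    Function(
        "ncclCommWindowRegister", ncclResult_t,
        [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t,
         ctypes.POINTER(ctypes.c_void_p), ctypes.c_int]),
]


def bind(lib: ctypes.CDLL, is_cuda: bool) -> dict[str, Any]:
    """Resolve only the specs valid for this platform.

    getattr(lib, name) raises AttributeError when the shared library
    does not export the symbol, which is exactly the failure the
    is_cuda gate avoids on libraries lacking the CUDA-only API.
    """
    specs = list(base_specs)
    if is_cuda:
        specs.extend(cuda_only_specs)
    funcs: dict[str, Any] = {}
    for spec in specs:
        f = getattr(lib, spec.name)
        f.restype = spec.restype
        f.argtypes = spec.argtypes
        funcs[spec.name] = f
    return funcs
```

Building `function_specs` as a fresh list (`list(NCCLLibrary.exported_functions)` plus a conditional `extend`) rather than mutating the class attribute keeps the shared spec table intact if the wrapper is instantiated on mixed configurations within one process.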