diff --git a/setup.py b/setup.py
index 2644c03b1637..8cb0d7bbbd68 100644
--- a/setup.py
+++ b/setup.py
@@ -47,12 +47,6 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
         raise RuntimeError(
             "GPUs with compute capability less than 7.0 are not supported.")
     compute_capabilities.add(major * 10 + minor)
-# If no GPU is available, add all supported compute capabilities.
-if not compute_capabilities:
-    compute_capabilities = {70, 75, 80, 86, 90}
-# Add target compute capabilities to NVCC flags.
-for capability in compute_capabilities:
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
 
 # Validate the NVCC CUDA version.
 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
@@ -65,6 +59,18 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     raise RuntimeError(
         "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
 
+# If no GPU is available, add all supported compute capabilities.
+if not compute_capabilities:
+    compute_capabilities = {70, 75, 80}
+    if nvcc_cuda_version >= Version("11.1"):
+        compute_capabilities.add(86)
+    if nvcc_cuda_version >= Version("11.8"):
+        compute_capabilities.add(90)
+
+# Add target compute capabilities to NVCC flags.
+for capability in compute_capabilities:
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
     num_threads = min(os.cpu_count(), 8)