diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index b84ebe95d525..7837d8cbb570 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -828,6 +828,35 @@ def CUDAExtension(name, sources, *args, **kwargs):
             cmdclass={
                 'build_ext': BuildExtension
             })
+
+    Compute capabilities:
+
+    By default the extension will be compiled to run on all archs of the cards visible during the
+    building process of the extension, plus PTX. If down the road a new card is installed the
+    extension may need to be recompiled. If a visible card has a compute capability (CC) that's
+    newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch
+    will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
+    support (see below for details on PTX).
+
+    You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
+    CCs you want the extension to support:
+
+    TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py
+    TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py
+
+    The +PTX option causes extension kernel binaries to include PTX instructions for the specified
+    CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
+    the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
+    CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
+    provide forward compatibility by runtime-compiling for newer CCs can modestly reduce performance
+    on those newer CCs. If you know the exact CC(s) of the GPUs you want to target, you're always
+    better off specifying them individually. For example, if you want your extension to run on 8.0
+    and 8.6, "8.0+PTX" would work functionally because it includes PTX that can runtime-compile
+    for 8.6, but "8.0 8.6" would be better.
+
+    Note that while it's possible to include all supported archs, the more archs get included the
+    slower the building process will be, as it will build a separate kernel image for each arch.
+
     '''
     library_dirs = kwargs.get('library_dirs', [])
     library_dirs += library_paths(cuda=True)
@@ -1496,16 +1525,24 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
 
     # If not given, determine what's best for the GPU / CUDA version that can be found
     if not _arch_list:
-        capability = torch.cuda.get_device_capability()
-        supported_sm = [int(arch.split('_')[1])
-                        for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
-        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
-        # Capability of the device may be higher than what's supported by the user's
-        # NVCC, causing compilation error. User's NVCC is expected to match the one
-        # used to build pytorch, so we use the maximum supported capability of pytorch
-        # to clamp the capability.
-        capability = min(max_supported_sm, capability)
-        arch_list = [f'{capability[0]}.{capability[1]}']
+        arch_list = []
+        # the assumption is that the extension should run on any of the currently visible cards,
+        # which could be of different types - therefore all archs for visible cards should be included
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            supported_sm = [int(arch.split('_')[1])
+                            for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
+            max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
+            # Capability of the device may be higher than what's supported by the user's
+            # NVCC, causing compilation error. User's NVCC is expected to match the one
+            # used to build pytorch, so we use the maximum supported capability of pytorch
+            # to clamp the capability.
+            capability = min(max_supported_sm, capability)
+            arch = f'{capability[0]}.{capability[1]}'
+            if arch not in arch_list:
+                arch_list.append(arch)
+        arch_list = sorted(arch_list)
+        arch_list[-1] += '+PTX'
     else:
         # Deal with lists that are ' ' separated (only deal with ';' after)
         _arch_list = _arch_list.replace(' ', ';')
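
For illustration, below is a minimal standalone sketch of the new default selection logic, not part of the patch itself. The helper name default_cuda_arch_list is hypothetical; torch.cuda.device_count, torch.cuda.get_device_capability, and torch.cuda.get_arch_list are the same public APIs the patch uses. It assumes a CUDA-enabled PyTorch build with at least one visible GPU.

    # Standalone sketch (not part of the patch) of the new default arch selection.
    from typing import List

    import torch

    def default_cuda_arch_list() -> List[str]:
        # hypothetical helper name, for illustration only
        if torch.cuda.device_count() == 0:
            raise RuntimeError("no visible CUDA devices to derive an arch list from")
        # Newest SM this PyTorch build was compiled for, e.g. (8, 6) for sm_86;
        # used to clamp card capabilities to what the matching nvcc can target.
        supported_sm = [int(arch.split('_')[1])
                        for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
        arch_list: List[str] = []
        for i in range(torch.cuda.device_count()):
            # get_device_capability returns a (major, minor) tuple, so tuple
            # comparison via min() clamps e.g. (9, 0) down to (8, 6).
            capability = min(max_supported_sm, torch.cuda.get_device_capability(i))
            arch = f'{capability[0]}.{capability[1]}'
            if arch not in arch_list:
                arch_list.append(arch)
        arch_list = sorted(arch_list)
        # +PTX on the newest arch gives forward compatibility with future cards.
        arch_list[-1] += '+PTX'
        return arch_list

    # e.g. on a machine with a Tesla P4 (CC 6.1) and an RTX 3090 (CC 8.6):
    # default_cuda_arch_list() -> ['6.1', '8.6+PTX']

Sorting before appending '+PTX' is what guarantees the suffix lands on the newest arch in the list, so the embedded PTX covers the widest possible range of future CCs.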