[CUDAExtension] support all visible cards when building a cudaextension (#48891)

Summary:
Currently CUDAExtension assumes that all cards on a machine are of the same type, and builds the extension with the compute capability of the 0th card. This breaks later at runtime if the machine has cards of different types.

Specifically resulting in:
```
RuntimeError: CUDA error: no kernel image is available for execution on the device
```
when cards of a type that wasn't compiled for are used (and, to the uninitiated, the error is far from telling what the problem actually is).

My current setup is:
```
$ CUDA_VISIBLE_DEVICES=0 python -c "import torch; print(torch.cuda.get_device_capability())"
(8, 6)
$ CUDA_VISIBLE_DEVICES=1 python -c "import torch; print(torch.cuda.get_device_capability())"
(6, 1)
```
but the extension was getting built with `-gencode=arch=compute_80,code=sm_80`.
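For reference, here is a minimal sketch (using only public `torch.cuda` APIs) that enumerates every visible card and its compute capability — the set of devices the fix below loops over:

```
import torch

# Print the compute capability of every CUDA device visible to this
# process; a heterogeneous machine will show more than one value.
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"device {i}: compute capability {major}.{minor}")
```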

This PR:
* [x] introduces a loop over all devices visible at build time, to ensure the extension will run on all of them (the list generated by the loop is sorted, so that the output is easier to debug should a card with a lower capability come last)
* [x] adds `+PTX` to the last entry of the CCs derived from local cards (the `if not _arch_list:` branch) to support other archs (see the sketch after this list)
* [x] adds a digest of my conversation with ptrblck on Slack, in the form of docs, which hopefully can help others work out which archs to support, how to override the defaults, when and how to add PTX, etc.
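For the heterogeneous setup above, the new behavior can be checked directly. A sketch calling the private helper `_get_cuda_arch_flags` that this PR modifies (the output in the comment is what I'd expect for 6.1 and 8.6 cards, not a captured log; it requires CUDA to be available and `TORCH_CUDA_ARCH_LIST` to be unset):

```
from torch.utils.cpp_extension import _get_cuda_arch_flags

# With TORCH_CUDA_ARCH_LIST unset, the flags now cover every visible card,
# with PTX appended for the highest arch. For 6.1 and 8.6 cards this should
# yield something like:
#   ['-gencode=arch=compute_61,code=sm_61',
#    '-gencode=arch=compute_86,code=sm_86',
#    '-gencode=arch=compute_86,code=compute_86']
print(_get_cuda_arch_flags())
```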

Please kindly check that my prose is clear and easy to understand.

ptrblck

Pull Request resolved: #48891

Reviewed By: ngimel

Differential Revision: D25358285

Pulled By: ezyang

fbshipit-source-id: 8160f3adebffbc8e592ddfcc3adf153a9dc91557
stas00 authored and facebook-github-bot committed Dec 8, 2020
1 parent 6000481 commit 02b6385
torch/utils/cpp_extension.py (47 additions, 10 deletions)
```
@@ -828,6 +828,35 @@ def CUDAExtension(name, sources, *args, **kwargs):
             cmdclass={
                 'build_ext': BuildExtension
             })
+
+    Compute capabilities:
+
+    By default the extension will be compiled to run on all archs of the cards visible during the
+    building process of the extension, plus PTX. If down the road a new card is installed the
+    extension may need to be recompiled. If a visible card has a compute capability (CC) that's
+    newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch
+    will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
+    support (see below for details on PTX).
+
+    You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
+    CCs you want the extension to support:
+
+    TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py
+    TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py
+
+    The +PTX option causes extension kernel binaries to include PTX instructions for the specified
+    CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
+    the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
+    CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
+    provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on
+    those newer CCs. If you know the exact CC(s) of the GPUs you want to target, you're always
+    better off specifying them individually. For example, if you want your extension to run on
+    8.0 and 8.6, "8.0+PTX" would work functionally because it includes PTX that can runtime-compile
+    for 8.6, but "8.0 8.6" would be better.
+
+    Note that while it's possible to include all supported archs, the more archs get included the
+    slower the building process will be, as it will build a separate kernel image for each arch.
     '''
     library_dirs = kwargs.get('library_dirs', [])
     library_dirs += library_paths(cuda=True)
```
```
@@ -1496,16 +1525,24 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:

     # If not given, determine what's best for the GPU / CUDA version that can be found
     if not _arch_list:
-        capability = torch.cuda.get_device_capability()
-        supported_sm = [int(arch.split('_')[1])
-                        for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
-        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
-        # Capability of the device may be higher than what's supported by the user's
-        # NVCC, causing compilation error. User's NVCC is expected to match the one
-        # used to build pytorch, so we use the maximum supported capability of pytorch
-        # to clamp the capability.
-        capability = min(max_supported_sm, capability)
-        arch_list = [f'{capability[0]}.{capability[1]}']
+        arch_list = []
+        # the assumption is that the extension should run on any of the currently visible cards,
+        # which could be of different types - therefore all archs for visible cards should be included
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            supported_sm = [int(arch.split('_')[1])
+                            for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
+            max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
+            # Capability of the device may be higher than what's supported by the user's
+            # NVCC, causing compilation error. User's NVCC is expected to match the one
+            # used to build pytorch, so we use the maximum supported capability of pytorch
+            # to clamp the capability.
+            capability = min(max_supported_sm, capability)
+            arch = f'{capability[0]}.{capability[1]}'
+            if arch not in arch_list:
+                arch_list.append(arch)
+        arch_list = sorted(arch_list)
+        arch_list[-1] += '+PTX'
     else:
         # Deal with lists that are ' ' separated (only deal with ';' after)
         _arch_list = _arch_list.replace(' ', ';')
```
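As a usage note, here is a sketch of a setup script that picks up the new behavior (hypothetical names: `my_ext` and `my_ext.cu` are placeholders; the `TORCH_CUDA_ARCH_LIST` override is the escape hatch documented in the docstring above):

```
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# With TORCH_CUDA_ARCH_LIST unset, the arch flags are now derived from
# every visible card rather than only card 0. To pin the archs explicitly
# instead, run e.g.:
#   TORCH_CUDA_ARCH_LIST="6.1 8.6+PTX" python setup.py install
setup(
    name='my_ext',
    ext_modules=[CUDAExtension('my_ext', ['my_ext.cu'])],
    cmdclass={'build_ext': BuildExtension})
```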
