
[CUDAExtension] support all visible cards when building a CUDAExtension #48891

60 changes: 50 additions & 10 deletions torch/utils/cpp_extension.py
@@ -828,6 +828,39 @@ def CUDAExtension(name, sources, *args, **kwargs):
    cmdclass={
        'build_ext': BuildExtension
    })

Compute capabilities:

By default the extension will be compiled to run on all archs of the cards visible during the
build. If a new card is installed later, the extension may need to be recompiled.

This default may also clamp a higher arch down to a lower one with the same binary
compatibility, if pytorch was compiled against only the lower one, e.g. ``sm_86`` could be
clamped down to ``sm_80``. For a variety of technical reasons the distributed pytorch binary
isn't built against the full range of compute capabilities, e.g. it includes only ``sm_60``
and ``sm_70``, but not ``sm_61`` and ``sm_75``. The omitted archs are binary compatible with
the included ones, but you might not be getting the best performance out of them.
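
To see whether clamping would apply on your system, compare the archs your pytorch binary
was built against with the capability of your card (a quick sketch; the exact values depend
on your build and hardware):

    import torch
    print(torch.cuda.get_arch_list())           # e.g. ['sm_60', 'sm_70']
    print(torch.cuda.get_device_capability(0))  # e.g. (8, 6)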

However, you can explicitly specify which archs you want the extension to support like so:

TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py
TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py

The ``+PTX`` option is special: when appended as in the last example, PTX code is embedded
alongside the binary kernels, so a card whose compute capability wasn't explicitly compiled
for can still run the extension via JIT compilation at runtime. Since PTX is only forward
compatible, cards older than the lowest arch in the list still won't work.
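
If it's more convenient, the same can be set from within the build script itself (a minimal
sketch; the chosen archs are just an example):

    import os
    # equivalent to prefixing the command with TORCH_CUDA_ARCH_LIST=...;
    # must be set before the extension gets built
    os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1 8.6+PTX'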

Notes:

- the more archs get included, the slower the build will be, as a separate kernel image is
  compiled for each arch

- to get the best performance it's best to compile for the exact compute capability of the
  cards the extension will run on. e.g. while ``sm_80`` kernels will run just fine on an
  ``sm_86`` card, they can't take advantage of the new instructions only available on
  ``sm_86``.
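
To check which archs will actually be used, you can call the helper that computes the nvcc
flags (a sketch; the helper is private and so subject to change, and the output depends on
your setup):

    from torch.utils.cpp_extension import _get_cuda_arch_flags
    # e.g. ['-gencode=arch=compute_86,code=sm_86']
    print(_get_cuda_arch_flags())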

    '''
    library_dirs = kwargs.get('library_dirs', [])
    library_dirs += library_paths(cuda=True)
@@ -1496,16 +1529,23 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:

    # If not given, determine what's best for the GPU / CUDA version that can be found
    if not _arch_list:
        capability = torch.cuda.get_device_capability()
        supported_sm = [int(arch.split('_')[1])
                        for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
        # Capability of the device may be higher than what's supported by the user's
        # NVCC, causing compilation error. User's NVCC is expected to match the one
        # used to build pytorch, so we use the maximum supported capability of pytorch
        # to clamp the capability.
        capability = min(max_supported_sm, capability)
        arch_list = [f'{capability[0]}.{capability[1]}']
        arch_list = []
        # the assumption is that the extension should run on any of the currently visible cards,
        # which could be of different types - therefore all archs for visible cards should be included
        supported_sm = [int(arch.split('_')[1])
                        for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
        for i in range(torch.cuda.device_count()):
            capability = torch.cuda.get_device_capability(i)
            # Capability of the device may be higher than what's supported by the user's
            # NVCC, causing compilation error. User's NVCC is expected to match the one
            # used to build pytorch, so we use the maximum supported capability of pytorch
            # to clamp the capability.
            capability = min(max_supported_sm, capability)
            arch = f'{capability[0]}.{capability[1]}'
            if arch not in arch_list:
                arch_list.append(arch)
        arch_list = sorted(arch_list)
    else:
        # Deal with lists that are ' ' separated (only deal with ';' after)
        _arch_list = _arch_list.replace(' ', ';')