
C++17 for PyTorch #56055

Closed

smessmer opened this issue Apr 14, 2021 · 14 comments
Labels
better-engineering Relatively self-contained tasks for better engineering contributors module: build Build system issues module: cuda Related to torch.cuda, and CUDA support in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@smessmer
Contributor

smessmer commented Apr 14, 2021

We are planning to migrate the PyTorch codebase to C++17, but are currently blocked by CUDA. This issue summarizes a discussion I had with @malfet and @ngimel about this.

CUDA 11 is the first CUDA version to support C++17, so we'd have to drop support for CUDA 10, but there are still good reasons for us to keep CUDA 10 around:

  1. A PyTorch build with CUDA 10 takes around 800MB, while a PyTorch build with CUDA 11 takes 1.9GB. This is a lot of data to put in a wheel. The size difference is both due to CUDA 11 libraries generally being larger, and also due to CUDA 11 wheels supporting 2 additional architectures, sm_80 and sm_86.
  2. PyTorch wheels and conda packages built with CUDA 11 are slower than CUDA 10 builds for the architectures they both support (sm_75) because fast kernels are dropped due to static linking. Additionally, our CUDA 11 wheels are slower than dynamically linked CUDA 11 builds on Ampere (sm_86) because of the dropped kernels, but that’s a separate issue that needs to be solved.

These issues would have to be fixed in CUDA, there's not much we can do on the PyTorch side.

Workarounds we considered (kudos to @malfet and @ngimel for the ideas)

  • Only build .cpp files with C++17 while building .cu files in C++14 mode.
    • Since .cu files include a lot of our core headers, this means we'd still be restricted to C++14 in many parts of the codebase. Furthermore, if you accidentally use a C++17 feature today and do a local CPU build, you get a local compiler error, whereas in a dual C++14/C++17 world it would only fail on CI. This would hurt developer productivity.
  • Stay with CUDA 10 but invoke nvcc with -std=c++14 '-Xcompiler -std=c++17'
    • This could work for some C++17 features, namely ones that don't change the C++ syntax, but C++17 does introduce a couple of features that do change the syntax (see the C++17 syntax sketch after this list), and those would likely trip up nvcc when it parses the code. Even if it worked today, this would be a fragile setup that could break in the future.
  • Use a Frankenstein CUDA where most components are from CUDA 10, but nvcc and cudafe are from CUDA 11.
    • This would add a significant maintenance burden. Also, it would likely only work at Facebook, where we fully control the build. It would leave behind people who build in open source with the publicly available CUDA distribution, unless they use the Frankenstein toolkit too, and we don't want to be in the business of distributing that. Leaving OSS users behind is not an option.
  • Split the wheel
    • Instead of distributing PyTorch in one wheel, we could have users download two wheels, e.g. one for PyTorch and one for CUDA, or maybe one for PyTorch CPU and one for PyTorch CUDA kernels + CUDA library. This could alleviate the binary size problem for some users, namely those who don't use CUDA, but GPU users would still have to download both. Distributing CUDA separately could work well for conda, but could pose issues with pip.
  • Link CUDA dynamically
    • At least some of the perf issues in CUDA 11 seem to be due to cuDNN not using fast kernels if linked statically. By linking dynamically, we could fix those perf problems. However, that would make the binary size problem even worse since dynamically linking CUDA would prevent the linker from stripping away unused parts of it. Also, we wouldn't be able to link it with -fvisibility=hidden, so it could cause symbol conflicts with other libraries.
  • Change the way we integrate CPU/CUDA, i.e. stop using triple chevrons cuda_kernel<<<1, 1>>>(argument, argument, argument)
    • If we migrate our CUDA code away from triple chevrons towards cuLaunchKernel (see the cuLaunchKernel sketch after this list), we would be able to avoid nvcc for code outside of CUDA kernels, which would allow us to write it in C++17. However, this would be a nontrivial engineering effort. Also, Kineto isn't correctly recording kernels launched with cuLaunchKernel, so that would have to be fixed too.
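
To make the syntax point above concrete, here is a small, purely hypothetical header (made-up names, not PyTorch code) using C++17-only syntax, which a C++14 frontend such as nvcc 10's will generally reject at parse time no matter which -std flag is forwarded to the host compiler via -Xcompiler:

    // illustrative header, names are made up
    #include <map>
    #include <string>

    namespace torch_example::detail {                      // nested namespace definition (C++17)

    inline constexpr int kDefaultValue = 0;                // inline variable (C++17)

    inline int lookup(const std::map<std::string, int>& m) {
      if (auto it = m.find("key"); it != m.end()) {        // if with init-statement (C++17)
        auto [key, value] = *it;                           // structured bindings (C++17)
        (void)key;
        return value;
      }
      return kDefaultValue;
    }

    }  // namespace torch_example::detail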
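
And a rough sketch of the cuLaunchKernel idea, assuming the kernel is still compiled by nvcc into a module (cubin/PTX) while the host-side launcher becomes a plain C++17 translation unit. This is an illustration with made-up names, not PyTorch's actual launcher:

    // kernel.cu: still compiled by nvcc, e.g. into kernels.cubin
    extern "C" __global__ void add_one(float* data, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) data[i] += 1.0f;
    }

    // launcher.cpp: plain host C++17, never passes through nvcc
    #include <cuda.h>
    #include <cstdio>

    void launch_add_one(CUfunction fn, CUdeviceptr d_data, int n, CUstream stream) {
      void* args[] = {&d_data, &n};                      // addresses of the kernel arguments
      CUresult err = cuLaunchKernel(fn,
                                    (n + 255) / 256, 1, 1,  // grid dims
                                    256, 1, 1,              // block dims
                                    0,                      // dynamic shared memory bytes
                                    stream, args, /*extra=*/nullptr);
      if (err != CUDA_SUCCESS) {
        std::fprintf(stderr, "cuLaunchKernel failed: %d\n", static_cast<int>(err));
      }
    }
    // fn would be obtained via cuModuleLoad(&mod, "kernels.cubin") and
    // cuModuleGetFunction(&fn, mod, "add_one"); the triple-chevron form
    // add_one<<<grid, block, 0, stream>>>(ptr, n) only compiles under nvcc.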

cc @malfet @seemethere @walterddr @ngimel

@soumith added and then removed the module: cuda label Apr 14, 2021
@mruberry added the better-engineering, module: build, module: cuda, and triaged labels Apr 15, 2021
@rgommers
Collaborator

Split the wheel

This was also my first thought. It doesn't help with total download size, but at least it would keep each wheel under 1 GB, which is the absolute maximum PyPI allows. It looks like it's just about possible to split into two roughly equal parts that each stay under 1 GB without changing how the .so's are built. By making torch a namespace package or metapackage, it should be possible to do this in a way that doesn't change the end-user workflow.

Package sizes

To get an impression of file sizes, I unpacked a nightly wheel. It's 1.9 GB as a .whl and 3.9 GB unpacked. Here are the files in torch/lib/ that are >2 MB:

$ ls -l --block-size=M lib
total 3912M
   37M Mar 24 05:11 libcaffe2_detectron_ops_gpu.so
   22M Feb 11 09:56 libnvrtc-08c4863f.so.10.2
   32M Mar 24 05:12 libnvrtc-3a20f2b6.so.11.1
  330M Mar 24 05:11 libtorch_cpu.so
 1887M Mar 24 05:11 libtorch_cuda_cpp.so
 1350M Mar 24 05:11 libtorch_cuda_cu.so
  227M Mar 24 05:11 libtorch_cuda.so
   22M Mar 24 05:12 libtorch_python.so

Then I did the same for a conda nightly package. It's 1.3 GB packaged, and 3.1 GB unpacked.

$ ls -l --block-size=M lib
total 3130M
   45M Apr 15 10:57 libcaffe2_detectron_ops_gpu.so
  182M Apr 15 10:57 libtorch_cpu.so
 1556M Apr 15 10:57 libtorch_cuda_cpp.so
 1323M Apr 15 10:57 libtorch_cuda_cu.so
    1M Apr 15 10:57 libtorch_cuda.so
   23M Apr 15 10:57 libtorch_python.so
    1M Apr 15 10:57 libtorch.so

Dynamic vs static linking

By linking dynamically, we could fix those perf problems. However, that would make the binary size problem even worse since dynamically linking CUDA would prevent the linker from stripping away unused parts of it

I'm not sure I understand this. Dynamic linking should be much better for package size; just comparing the conda packages and wheels shows that. Also compare the conda-forge pytorch packages, which do everything dynamically and are 425 MB for CUDA 10.2, versus 700 MB for the pytorch nightlies, which I believe still partially use static linking. There may be other differences, but I think it's mostly dynamic linking: I have checked that conda-forge uses the default cuDNN version for each CUDA version and the same TORCH_CUDA_ARCH_LIST as pytorch/builder. These settings in the recipe probably explain the difference (both are 1 in pytorch/builder):

    export USE_STATIC_NCCL=0
    export USE_STATIC_CUDNN=0

If this is based on "but then we have to bundle in dynamic libs from CUDA itself", then I don't think that is the right comparison - doing that is not even allowed by the CUDA EULA as far as I know.

Other ideas

  • Build for fewer GPU architectures in a single wheel

    • Right now the build includes TORCH_CUDA_ARCH_LIST="3.7+PTX;5.0;6.0;6.1;7.0;7.5;8.0;8.6". That's a lot of architectures. I'm not sure it's common to have systems with a mix of very old and brand-new GPUs installed; that seems like something that could be unsupported, or could require a custom wheel that's not hosted on PyPI. The widget at https://pytorch.org/get-started/locally/ already includes the CUDA version; it could additionally include a selection for architectures (e.g., "Volta and up" as the default).
  • Break up pytorch further

    • This could be done per component (e.g. torch, torch_distributed, torch_nn, caffe2) or by using some of the build switches (USE_DISTRIBUTED and other USE_* flags). I'd wager that the average user who installs from PyPI does not need the distributed backend. The "extras" mechanism could be used to select components: for example, pip install torch gives a slimmed-down version, and pip install torch[all] gets you all the wheels.
  • Link MKL dynamically

    • This is a smaller gain, but I think the ~180 MB of libtorch_cpu.so is in significant part due to MKL. Push Intel to maintain https://pypi.org/project/mkl/#files a bit better and use it as a runtime dependency.

Longer term & related ideas

Right now only a single CUDA version can be put on PyPI. Currently that is CUDA 10.2; CUDA 11.1, ROCm 4.0 and CPU-only builds all live in a separate wheelhouse like https://download.pytorch.org/whl/torch_stable.html. That makes those builds very difficult to depend on; in effect, distributing a downstream package that depends on them is impractical. This should be solved somehow, and if the solution involves making it easier or better to host packages outside of PyPI (e.g. via custom wheelhouses or by interacting with another package manager), then the limitations of PyPI may become less of a limiting factor here as well.

Relevant discussion and blog post:

My personal impression is that the problem is only going to get worse (CUDA continues to grow, next year we're likely to get sm_90), and there is currently no clear plan for PyPI/pip/wheels to deal with even CUDA 11. ~2 GB wheels are not a healthy thing. A pragmatic solution could be:

  • Only put the cpuonly wheel on PyPI
  • Build a mechanism where downstream packages can declare a dependency on a package from another package manager. Not just Conda but also Homebrew and Linux distro package managers are able to package CUDA and use sane dynamic linking. So imagine a mapping where a pyproject.toml could have:
# Dependencies with compiled code including CUDA
requires_external = [
    "cuda==11",
    "torch==1.8.1",
    "torchvision==0.9.1",
    "numpy",
]

# Pure Python dependencies
requires = [
    "pytest",
    "imgaug",
]

And then cuda could map to cudatoolkit for conda, nvidia-cuda-toolkit for apt-get, etc. The point is not necessarily to invoke another package manager, but perhaps simply to check whether those things are already installed and, if they aren't, emit a helpful message to the user.

Possible next step

None of these solutions is ideal; there's no clear winner imho. Without more information, my guess would be that slimming down TORCH_CUDA_ARCH_LIST is the easiest thing to do and gains the most.

It would be great to have a better idea of what each technical change would bring: something like a "package size budget" where the total size is broken down into contributions from each dependency and build option. Does anyone have anything like this, and if not, would it be worth producing?

@smessmer
Contributor Author

Thanks for the thorough analysis, there are some great ideas in there. cc @ngimel @malfet

@rgommers
Collaborator

Build for fewer GPU architectures in a single wheel

Just noticed that gh-49050 (which split torch_cuda into torch_cuda_cu and torch_cuda_cpp) has a short discussion by @ezyang and @janeyx99 on the same idea, #49050 (comment)

@ptrblck
Collaborator

ptrblck commented Apr 27, 2021

Thanks for the detailed analysis! CC @eqy who is looking into the memory requirements as well.

@t-vi
Collaborator

t-vi commented Apr 28, 2021

One thing about dropping CUDA 10: the Jetsons are currently incompatible with CUDA 11 (they don't use PCIe, so they need their own driver, which as far as I understand has not yet been updated).

@swolchok
Contributor

Just to keep a list, here are some performance reasons to switch to C++17:

@xloem

xloem commented Jul 4, 2022

Note: PyTorch 1.12 uses aligned_alloc, which is a C++17 symbol according to cppreference.com

@cyyever
Collaborator

cyyever commented Jul 11, 2022

@xloem So pytorch requires C++17 in fact.

@xloem

xloem commented Jul 11, 2022

@xloem So pytorch requires C++17 in fact.

?

- A C++14 compatible compiler, such as clang

CXX_STANDARD 14

@daquexian
Contributor

I believe the aligned_alloc used in pytorch is the C11 aligned_alloc, not std::aligned_alloc from C++17: https://en.cppreference.com/w/c/memory/aligned_alloc
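
A minimal sketch of that distinction, assuming a glibc-based toolchain (the C function may not be visible, or may not exist at all, on other C libraries):

    #include <cstdlib>   // std::aligned_alloc is only declared here as of C++17
    #include <stdlib.h>  // ::aligned_alloc, the C11 function that glibc also exposes to C++

    int main() {
      // Typically compiles even with -std=c++14 on glibc: this is the plain C symbol.
      void* a = ::aligned_alloc(/*alignment=*/64, /*size=*/1024);
      ::free(a);

    #if __cplusplus >= 201703L
      // The namespaced form is only guaranteed to exist as of C++17.
      void* b = std::aligned_alloc(64, 1024);
      std::free(b);
    #endif
      return 0;
    }

Whether the unqualified C declaration is visible in pre-C++17 modes depends on the C library and the feature-test macros in play, which may be why the clang build mentioned below needed C++17.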

@miladm
Collaborator

miladm commented Jul 26, 2022

FWIW, as a framework dependency, PyTorch/XLA is now transitioning to C++17 (Reference PR).

CC @wconstab @ezyang @JackCaoG

@xloem

xloem commented Jul 27, 2022

I believe the aligned_alloc used in pytorch is the C11 aligned_alloc, not std::aligned_alloc from C++17: https://en.cppreference.com/w/c/memory/aligned_alloc

Just to follow up here: I haven't looked into the details of correctness. Some time ago I attempted to compile pytorch with clang, and the aligned_alloc code would not compile until I enabled C++17. I'm afraid I no longer have the LLVM version in question.

The aligned_alloc concern should be in a separate issue, and I apologise for clogging this thread. Since I am the only person who has mentioned it, it is probably minor.

malfet added a commit that referenced this issue Sep 30, 2022
malfet added a commit that referenced this issue Sep 30, 2022
malfet added a commit that referenced this issue Sep 30, 2022
malfet added a commit that referenced this issue Oct 1, 2022
malfet added a commit that referenced this issue Oct 12, 2022
malfet added a commit that referenced this issue Oct 26, 2022
malfet added a commit that referenced this issue Oct 27, 2022
malfet added a commit that referenced this issue Oct 28, 2022
malfet added a commit that referenced this issue Nov 10, 2022
malfet added a commit that referenced this issue Nov 19, 2022
malfet added a commit that referenced this issue Nov 23, 2022
malfet added a commit that referenced this issue Nov 24, 2022
malfet added a commit that referenced this issue Nov 29, 2022
malfet added a commit that referenced this issue Dec 5, 2022
malfet added a commit that referenced this issue Dec 6, 2022
malfet added a commit that referenced this issue Dec 7, 2022
malfet added a commit that referenced this issue Dec 7, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Dec 10, 2022
With CUDA-10.2 gone we can finally do it!

This PR mostly contains build system related changes; more invasive functional ones are to follow.
Among the many expected tweaks to the build system, here are a few unexpected ones:
 - Force the onnx_proto project to be updated to C++17 to avoid a `duplicate symbols` error when compiled by gcc-7.5.0, as the storage rule for `constexpr` changed in C++17, but gcc does not seem to follow it
 - Do not use `std::apply` on CUDA but rely on the built-in variant, as it results in test failures when the CUDA runtime picks the host rather than the device function when `std::apply` is invoked from CUDA code.
 - `std::decay_t` -> `::std::decay_t` and `std::move` -> `::std::move`, as VC++ for some reason claims that the `std` symbol is ambiguous
 - Disable use of `std::aligned_alloc` on Android, as its `libc++` does not implement it.

Some prerequisites:
 - pytorch#89297
 - pytorch#89605
 - pytorch#90228
 - pytorch#90389
 - pytorch#90379
 - pytorch#89570
 - facebookincubator/gloo#336
 - facebookincubator/gloo#343
 - pytorch/builder@919676f

Fixes pytorch#56055

Pull Request resolved: pytorch#85969
Approved by: https://github.com/ezyang, https://github.com/kulinseth
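
For context on the `constexpr` storage-rule remark above: in C++17 a static constexpr data member is implicitly an inline variable, so its definition is merged across translation units, whereas under C++14 an odr-used member needs exactly one out-of-line definition. A made-up illustration of the language rule (not the actual onnx_proto code):

    // limits.h (illustrative)
    struct Limits {
      static constexpr int kMaxRank = 8;  // C++17: implicitly inline, definitions merged by the linker
    };

    // Taking the address odr-uses the member, so a definition must exist somewhere.
    inline const int* max_rank_ptr() { return &Limits::kMaxRank; }

    // limits.cpp: required under C++14, redundant (and deprecated) under C++17
    // constexpr int Limits::kMaxRank;

Mixing translation units built under the two rules, or a compiler that does not apply the new rule consistently (as the commit message suggests for gcc 7.5.0), can then surface as duplicate-symbol or missing-symbol errors at link time.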
pruthvistony pushed a commit to ROCm/pytorch that referenced this issue Dec 20, 2022
pruthvistony pushed a commit to ROCm/pytorch that referenced this issue Jan 3, 2023
@Skylion007
Collaborator

I just opened a new issue, but was any thought given to upgrading the C standard to C17 as well?

@ezyang
Contributor

ezyang commented Feb 13, 2023

uhh, do we have that much C code? lol
