Description
cuDNN v8 introduces a completely new API, and the old v7 API will no longer be supported at some point in the future. PyTorch needs to adopt this new API, which requires a complete rewriting of PyTorch's cuDNN convolution binding.
The new cuDNN v8 implementation of convolution will be based on cudnn-frontend, which is the officially recommended way to use cuDNN v8.
To make the changes easy to write and review, this adoption will be incremental. Below is a roadmap of how we plan to do the work:
Stage 0: build a sketch
The purpose of this stage is to start the adoption process and add some basic infrastructure to PyTorch so that future work can be done incrementally and in parallel.
In #51390 and #50827, I have already refactored our convolution bindings into separate files: `ConvPlaceholders.cpp`, `ConvShared.cpp`, `ConvShared.h`, `Conv_v7.cpp`, and `Conv_v8.cpp`. Future work will mostly be in `Conv_v8.cpp`.
In #51390, a new build flag called `USE_EXPERIMENTAL_CUDNN_V8_API` is added. When PyTorch is built with `USE_EXPERIMENTAL_CUDNN_V8_API=1`, convolution forward and transposed convolution backward use a basic implementation of the cuDNN v8 convolution API. This basic implementation provides most features of convolution forward, except cuDNN benchmark mode. It is not meant to be fast or ready to use. Some basic correctness checking with unit tests is done to make sure there are no obvious mistakes, but we won't run any benchmarks, and we won't run any correctness check on real models either. Some non-obvious unit test failures are allowed.
At this stage, we don't recommend that any user build with `USE_EXPERIMENTAL_CUDNN_V8_API=1`.
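For context, PyTorch build flags like this one are passed as environment variables to the source build. A typical invocation might look like the following sketch (everything besides `USE_EXPERIMENTAL_CUDNN_V8_API` is an ordinary, illustrative PyTorch build setting, not something this issue prescribes):

```shell
# Build PyTorch from source with the experimental cuDNN v8 convolution binding
# enabled. USE_EXPERIMENTAL_CUDNN_V8_API is the flag added in #51390; USE_CUDNN
# is a standard build knob shown for context.
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
USE_EXPERIMENTAL_CUDNN_V8_API=1 USE_CUDNN=1 python setup.py develop
```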
Stage 1: feature complete
The purpose of this stage is to have a complete implementation. At the end of this stage, the v8 implementation should support all features that PyTorch supports, and its microbenchmark performance should be comparable to or faster than the v7 API. All unit tests should pass when PyTorch is built with `USE_EXPERIMENTAL_CUDNN_V8_API=1`. We will also add a new CircleCI pipeline to test PyTorch with `USE_EXPERIMENTAL_CUDNN_V8_API=1`.
Here is a list of features to implement based on #51390:
- cuDNN benchmark ([cuDNN v8] Extend current cuDNN convolution v8 API binding to support cuDNN benchmark #58859)
- convolution backward / transposed convolution forward ([cuDNN v8] Extend current cuDNN v8 API binding to support convolution backward and transposed convolution forward #58858)
- conv-bias-activation fusion ([cuDNN v8] Extend current cuDNN convolution v8 API binding to support conv-bias-activation fusion #58860)
- better error message with debugging information and python repro when failing (on par with v7, reference PR Print tensor shapes and convolution parameters when cuDNN exception is thrown #45023) ([cuDNN v8] Improve cuDNN convolution v8 API error reporting #58862)
- BFloat16 support ([cuDNN v8] Extend current cuDNN convolution v8 API binding to support BFloat16 #58861)
Here is a list of issues to resolve based on #51390:
- The heuristic/benchmark cache of engine configs is not thread-safe.
- NHWC/NDHWC tests are failing.
- Use `uintptr_t` to compute alignment.
- Code is still not well organized. For example, `ATen/cudnn/Descriptors.h` still contains lots of v7-only things.
At the end of this stage, interested early users might want to try `USE_EXPERIMENTAL_CUDNN_V8_API=1`, but v8 API support is still not considered ready to use, because validation is still limited to unit tests. Testing on real applications is required before considering enabling `USE_EXPERIMENTAL_CUDNN_V8_API` by default.
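One of the open issues listed above is that the heuristic/benchmark cache of engine configs is not thread-safe. A common fix is to serialize access to the map with a mutex; a minimal sketch under assumed stand-in types (in PyTorch the key would be the convolution parameters and the value a cuDNN engine config, not strings):

```cpp
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for the real key and value types.
using ConvKey = std::string;
using EngineConfig = std::string;

// A cache whose lookups and insertions are serialized by a mutex, so
// concurrent threads running convolutions cannot corrupt the map.
class BenchmarkCache {
 public:
  // Returns true and copies the cached config out if the key is present.
  bool find(const ConvKey& key, EngineConfig* out) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = map_.find(key);
    if (it == map_.end()) return false;
    *out = it->second;
    return true;
  }

  // Inserts or overwrites the config for the given key.
  void insert(const ConvKey& key, const EngineConfig& config) {
    std::lock_guard<std::mutex> guard(mutex_);
    map_[key] = config;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<ConvKey, EngineConfig> map_;
};
```

Holding the lock for the whole lookup keeps the sketch simple; a real implementation might prefer a reader-writer lock if lookups vastly outnumber insertions.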
Stage 2: thorough testing
We will enable `USE_EXPERIMENTAL_CUDNN_V8_API=1` internally to run PyTorch with the cuDNN v8 API on real models, to make sure there are no performance or functional issues. We will also seek further performance improvements.
Stage 3: enable by default [WE ARE HERE]
At this stage, the flag `USE_EXPERIMENTAL_CUDNN_V8_API` will be removed from PyTorch. PyTorch will use the cuDNN v8 API whenever it is available.
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @csarofeen @ptrblck @xwang233 @ngimel