diff --git a/.circleci/cimodel/data/simple/util/versions.py b/.circleci/cimodel/data/simple/util/versions.py index 3c9186df13aa..53d3a837248c 100644 --- a/.circleci/cimodel/data/simple/util/versions.py +++ b/.circleci/cimodel/data/simple/util/versions.py @@ -29,3 +29,6 @@ def __init__(self, major, minor): self.minor = minor super().__init__([self.major, self.minor], "cuda") + + def __str__(self): + return f"{self.major}.{self.minor}" diff --git a/.circleci/cimodel/data/windows_build_definitions.py b/.circleci/cimodel/data/windows_build_definitions.py index dea78411addb..c0e828eaab5e 100644 --- a/.circleci/cimodel/data/windows_build_definitions.py +++ b/.circleci/cimodel/data/windows_build_definitions.py @@ -86,10 +86,11 @@ def gen_tree(self): props_dict["executor"] = "windows-with-nvidia-gpu" props_dict["cuda_version"] = ( - miniutils.quote(str(self.cuda_version.major)) + miniutils.quote(str(self.cuda_version)) if self.cuda_version else "cpu" ) + props_dict["name"] = "_".join(name_parts) return [{key_name: props_dict}] diff --git a/.circleci/config.yml b/.circleci/config.yml index 8bdfb3c9c7bd..cdd66830986f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -325,7 +325,7 @@ pytorch_windows_params: &pytorch_windows_params default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -675,7 +675,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -737,7 +737,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -8077,7 +8077,7 @@ workflows: - postnightly - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" name: pytorch_windows_vs2019_py36_cuda10.1_build python_version: "3.6" use_cuda: "1" @@ -8086,7 +8086,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" executor: windows-with-nvidia-gpu name: pytorch_windows_vs2019_py36_cuda10.1_test1 python_version: "3.6" @@ -8099,7 +8099,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" executor: windows-with-nvidia-gpu name: pytorch_windows_vs2019_py36_cuda10.1_test2 python_version: "3.6" @@ -8112,7 +8112,7 @@ workflows: vc_year: "2019" - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" name: pytorch_windows_vs2019_py36_cuda11.1_build python_version: "3.6" use_cuda: "1" @@ -8121,7 +8121,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" executor: windows-with-nvidia-gpu filters: branches: @@ -8140,7 +8140,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" executor: windows-with-nvidia-gpu filters: branches: @@ -8204,7 +8204,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" filters: branches: only: diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh index b530521f7f2d..f1022e113fa4 100644 --- a/.circleci/scripts/binary_ios_upload.sh +++ b/.circleci/scripts/binary_ios_upload.sh @@ -34,7 +34,13 @@ touch version.txt echo $(date +%s) > version.txt zip -r ${ZIPFILE} install src version.txt LICENSE # upload to aws -brew install awscli +# Install conda then 'conda install' awscli +curl --retry 3 -o ~/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh +chmod +x ~/conda.sh +/bin/bash ~/conda.sh -b -p ~/anaconda +export PATH="~/anaconda/bin:${PATH}" +source ~/anaconda/bin/activate +conda install -c conda-forge awscli --yes set +x export AWS_ACCESS_KEY_ID=${AWS_S3_ACCESS_KEY_FOR_PYTORCH_BINARY_UPLOAD} export AWS_SECRET_ACCESS_KEY=${AWS_S3_ACCESS_SECRET_FOR_PYTORCH_BINARY_UPLOAD} diff --git a/.circleci/scripts/windows_cuda_install.sh b/.circleci/scripts/windows_cuda_install.sh index 8d615b674aa0..04a4c2ed43ff 100644 --- a/.circleci/scripts/windows_cuda_install.sh +++ b/.circleci/scripts/windows_cuda_install.sh @@ -1,13 +1,11 @@ #!/bin/bash set -eux -o pipefail -if [[ "$CUDA_VERSION" == "10" ]]; then - cuda_complete_version="10.1" +if [[ "$CUDA_VERSION" =~ ^10.* ]]; then cuda_installer_name="cuda_10.1.243_426.00_win10" msbuild_project_dir="CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions" cuda_install_packages="nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" -elif [[ "$CUDA_VERSION" == "11" ]]; then - cuda_complete_version="11.1" +elif [[ "$CUDA_VERSION" =~ ^11.* ]]; then cuda_installer_name="cuda_11.1.0_456.43_win10" msbuild_project_dir="visual_studio_integration/CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions" cuda_install_packages="nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" @@ -16,7 +14,7 @@ else exit 1 fi -if [[ "${CUDA_VERSION}" != "10" && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then +if [[ "$CUDA_VERSION" =~ ^11.* && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then cuda_install_packages="${cuda_install_packages} Display.Driver" fi @@ -48,7 +46,7 @@ then export NVTOOLSEXT_PATH="C:\\Program Files\\NVIDIA Corporation\\NvToolsExt\\" fi -if ! ls "/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${cuda_complete_version}/bin/nvcc.exe" +if ! ls "/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${CUDA_VERSION}/bin/nvcc.exe" then echo "CUDA installation failed" mkdir -p /c/w/build-results diff --git a/.circleci/scripts/windows_cudnn_install.sh b/.circleci/scripts/windows_cudnn_install.sh index 529710af79b2..62f54615677e 100644 --- a/.circleci/scripts/windows_cudnn_install.sh +++ b/.circleci/scripts/windows_cudnn_install.sh @@ -1,12 +1,10 @@ #!/bin/bash set -eux -o pipefail -if [[ "$CUDA_VERSION" == "10" ]]; then - cuda_complete_version="10.1" - cudnn_installer_name="cudnn-10.1-windows10-x64-v7.6.4.38" -elif [[ "$CUDA_VERSION" == "11" ]]; then - cuda_complete_version="11.1" - cudnn_installer_name="cudnn-11.1-windows-x64-v8.0.5.39" +if [[ "$CUDA_VERSION" =~ ^10.* ]]; then + cudnn_installer_name="cudnn-${CUDA_VERSION}-windows10-x64-v7.6.4.38" +elif [[ "$CUDA_VERSION" =~ ^11.* ]]; then + cudnn_installer_name="cudnn-${CUDA_VERSION}-windows-x64-v8.0.5.39" else echo "CUDNN for CUDA_VERSION $CUDA_VERSION is not supported yet" exit 1 @@ -16,6 +14,6 @@ cudnn_installer_link="https://ossci-windows.s3.amazonaws.com/${cudnn_installer_n curl --retry 3 -O $cudnn_installer_link 7z x ${cudnn_installer_name}.zip -ocudnn -cp -r cudnn/cuda/* "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${cuda_complete_version}/" +cp -r cudnn/cuda/* "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${CUDA_VERSION}/" rm -rf cudnn rm -f ${cudnn_installer_name}.zip diff --git a/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml b/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml index e031e01ba846..c912a4fb690b 100644 --- a/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml +++ b/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml @@ -59,7 +59,7 @@ pytorch_windows_params: &pytorch_windows_params default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" diff --git a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml index 8d8036ea9523..aa0e2d2c5581 100644 --- a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml +++ b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml @@ -237,7 +237,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -299,7 +299,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4a0fb9cbf819..b04e4c82f57e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -75,7 +75,7 @@ jobs: - name: Run flake8 run: | set -eux - pip install flake8==3.8.2 flake8-bugbear==20.1.4 flake8-comprehensions==3.3.0 flake8-executable==2.0.4 flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 + pip install -r requirements-flake8.txt flake8 --version flake8 | tee ${GITHUB_WORKSPACE}/flake8-output.txt - name: Add annotations diff --git a/.jenkins/caffe2/build.sh b/.jenkins/caffe2/build.sh index 56ce8d525f89..0a4d1166bd05 100755 --- a/.jenkins/caffe2/build.sh +++ b/.jenkins/caffe2/build.sh @@ -18,49 +18,6 @@ build_to_cmake () { SCCACHE="$(which sccache)" -if [ "$(which gcc)" != "/root/sccache/gcc" ]; then - # Setup SCCACHE - ############################################################################### - # Setup sccache if SCCACHE_BUCKET is set - if [ -n "${SCCACHE_BUCKET}" ]; then - mkdir -p ./sccache - - SCCACHE="$(which sccache)" - if [ -z "${SCCACHE}" ]; then - echo "Unable to find sccache..." - exit 1 - fi - - # Setup wrapper scripts - wrapped="cc c++ gcc g++ x86_64-linux-gnu-gcc" - if [[ "${BUILD_ENVIRONMENT}" == *-cuda* ]]; then - wrapped="$wrapped nvcc" - fi - for compiler in $wrapped; do - ( - echo "#!/bin/sh" - - # TODO: if/when sccache gains native support for an - # SCCACHE_DISABLE flag analogous to ccache's CCACHE_DISABLE, - # this can be removed. Alternatively, this can be removed when - # https://github.com/pytorch/pytorch/issues/13362 is fixed. - # - # NOTE: carefully quoted - we want `which compiler` to be - # resolved as we execute the script, but SCCACHE_DISABLE and - # $@ to be evaluated when we execute the script - echo 'test $SCCACHE_DISABLE && exec '"$(which $compiler)"' "$@"' - - echo "exec $SCCACHE $(which $compiler) \"\$@\"" - ) > "./sccache/$compiler" - chmod +x "./sccache/$compiler" - done - - export CACHE_WRAPPER_DIR="$PWD/sccache" - - # CMake must find these wrapper scripts - export PATH="$CACHE_WRAPPER_DIR:$PATH" - fi -fi # Setup ccache if configured to use it (and not sccache) if [ -z "${SCCACHE}" ] && which ccache > /dev/null; then diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index e14828dc5afd..55b63d2144d0 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -1,5 +1,7 @@ #!/bin/bash +set -ex + # Required environment variable: $BUILD_ENVIRONMENT # (This is set by default in the Docker images we build, so you don't # need to set it yourself. @@ -7,13 +9,6 @@ # shellcheck disable=SC2034 COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" -# Temp: use new sccache -if [[ -n "$IN_CI" && "$BUILD_ENVIRONMENT" == *rocm* ]]; then - # Download customized sccache - sudo curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache - sudo chmod 755 /opt/cache/bin/sccache -fi - source "$(dirname "${BASH_SOURCE[0]}")/common.sh" if [[ "$BUILD_ENVIRONMENT" == *-linux-xenial-py3-clang5-asan* ]]; then @@ -124,32 +119,6 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then export MAX_JOBS=$(($(nproc) - 1)) fi - # ROCm CI is using Caffe2 docker images, which needs these wrapper - # scripts to correctly use sccache. - if [[ -n "${SCCACHE_BUCKET}" && -z "$IN_CI" ]]; then - mkdir -p ./sccache - - SCCACHE="$(which sccache)" - if [ -z "${SCCACHE}" ]; then - echo "Unable to find sccache..." - exit 1 - fi - - # Setup wrapper scripts - for compiler in cc c++ gcc g++ clang clang++; do - ( - echo "#!/bin/sh" - echo "exec $SCCACHE $(which $compiler) \"\$@\"" - ) > "./sccache/$compiler" - chmod +x "./sccache/$compiler" - done - - export CACHE_WRAPPER_DIR="$PWD/sccache" - - # CMake must find these wrapper scripts - export PATH="$CACHE_WRAPPER_DIR:$PATH" - fi - if [[ -n "$IN_CI" ]]; then # Set ROCM_ARCH to gfx900 and gfx906 for CI builds echo "Limiting PYTORCH_ROCM_ARCH to gfx90[06] for CI builds" diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh index 0f015df045c2..17e7e9fa3445 100755 --- a/.jenkins/pytorch/codegen-test.sh +++ b/.jenkins/pytorch/codegen-test.sh @@ -37,7 +37,8 @@ python -m tools.setup_helpers.generate_code \ mkdir -p "$OUT"/pyi/torch/_C mkdir -p "$OUT"/pyi/torch/nn python -m tools.pyi.gen_pyi \ - --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \ + --native-functions-path aten/src/ATen/native/native_functions.yaml \ + --deprecated-functions-path tools/autograd/deprecated.yaml \ --out "$OUT"/pyi # autograd codegen (called by torch codegen but can run independently) diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh index 9a2c486610c4..fdf3c03e7f67 100755 --- a/.jenkins/pytorch/multigpu-test.sh +++ b/.jenkins/pytorch/multigpu-test.sh @@ -21,4 +21,5 @@ time python test/run_test.py --verbose -i distributed/test_jit_c10d time python test/run_test.py --verbose -i distributed/test_distributed_fork time python test/run_test.py --verbose -i distributed/test_c10d time python test/run_test.py --verbose -i distributed/test_c10d_spawn +time python test/run_test.py --verbose -i distributed/rpc/test_tensorpipe_agent assert_git_not_dirty diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 1f1f174e992e..8e9afd5c9bc3 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -11,6 +11,8 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" +export LANG=C.UTF-8 + if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then export PYTORCH_TEST_WITH_SLOW=1 export PYTORCH_TEST_SKIP_FAST=1 diff --git a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat index f41e5f7fcd1b..7165f75a0e41 100644 --- a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat +++ b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat @@ -37,33 +37,19 @@ if "%VC_VERSION%" == "" ( @echo on popd -if "%CUDA_VERSION%" == "9" goto cuda_build_9 -if "%CUDA_VERSION%" == "10" goto cuda_build_10 -if "%CUDA_VERSION%" == "11" goto cuda_build_11 -goto cuda_build_end +if not "%USE_CUDA%"=="1" goto cuda_build_end -:cuda_build_9 +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION% -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2 -set CUDA_PATH_V9_2=%CUDA_PATH% +rem version transformer, for example 10.1 to 10_1. +set VERSION_SUFFIX=%CUDA_VERSION:.=_% +set CUDA_PATH_V%VERSION_SUFFIX%=%CUDA_PATH% -goto cuda_build_common - -:cuda_build_10 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1 -set CUDA_PATH_V10_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_11 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1 -set CUDA_PATH_V11_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_common +set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 +set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% +set CUDNN_ROOT_DIR=%CUDA_PATH% +set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt +set PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH% set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat index d4821c1b1a8d..ab102a0ea423 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat @@ -1,9 +1,9 @@ -if "%CUDA_VERSION%" == "9" set CUDA_SUFFIX=cuda92 -if "%CUDA_VERSION%" == "10" set CUDA_SUFFIX=cuda101 -if "%CUDA_VERSION%" == "11" set CUDA_SUFFIX=cuda110 +rem remove dot in cuda_version, fox example 11.1 to 111 +set VERSION_SUFFIX=%CUDA_VERSION:.=% +set CUDA_SUFFIX=cuda%VERSION_SUFFIX% if "%CUDA_SUFFIX%" == "" ( - echo unknown CUDA version, please set `CUDA_VERSION` to 9, 10 or 11. + echo unknown CUDA version, please set `CUDA_VERSION` higher than 9.2 exit /b 1 ) diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat index e3625ae75e9e..a052a1b67d59 100644 --- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat +++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat @@ -46,33 +46,13 @@ if %errorlevel% neq 0 ( exit /b %errorlevel% ) set DISTUTILS_USE_SDK=1 -if "%CUDA_VERSION%" == "9" goto cuda_build_9 -if "%CUDA_VERSION%" == "10" goto cuda_build_10 -if "%CUDA_VERSION%" == "11" goto cuda_build_11 -goto cuda_build_end +if not "%USE_CUDA%"=="1" goto cuda_build_end -:cuda_build_9 +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION% -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2 -set CUDA_PATH_V9_2=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_10 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1 -set CUDA_PATH_V10_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_11 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1 -set CUDA_PATH_V11_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_common +rem version transformer, for example 10.1 to 10_1. +set VERSION_SUFFIX=%CUDA_VERSION:.=_% +set CUDA_PATH_V%VERSION_SUFFIX%=%CUDA_PATH% set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% diff --git a/.travis.aten.yml b/.travis.aten.yml deleted file mode 100644 index 242584549625..000000000000 --- a/.travis.aten.yml +++ /dev/null @@ -1,31 +0,0 @@ -# https://travis-ci.org/zdevito/ATen -language: python -python: - - 2.7 - - 3.6 - -dist: trusty - -before_install: - - sudo apt-get install -qq valgrind - -install: - - travis_retry pip install pyyaml typing - -script: - - cd aten - - mkdir build install - - cd build - - cmake .. -DUSE_CUDA=OFF -DCMAKE_INSTALL_PREFIX=../install - - make install - - ../tools/run_tests.sh . - - cd .. - - tools/test_install.sh $(pwd)/install $(pwd) - -matrix: - fast_finish: true - include: - env: LINT_CHECK - python: "2.7" - install: pip install flake8-mypy - script: flake8 diff --git a/BUILD.bazel b/BUILD.bazel index 76afe6aec1ea..ec5111c5104d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -339,7 +339,8 @@ filegroup( "aten/src/ATen/cuda/CUDABlas.cpp", "aten/src/ATen/cuda/CUDASolver.cpp", "aten/src/ATen/cuda/CUDAContext.cpp", - "aten/src/ATen/cuda/CUDAGenerator.cpp", + "aten/src/ATen/cuda/CUDAGeneratorImpl.cpp", + "aten/src/ATen/cuda/CUDAGraph.cpp", "aten/src/ATen/cuda/CuSparseHandlePool.cpp", "aten/src/ATen/cuda/CublasHandlePool.cpp", "aten/src/ATen/cuda/CusolverDnHandlePool.cpp", @@ -544,6 +545,7 @@ header_template_rule( substitutions = { "@AT_MKLDNN_ENABLED@": "1", "@AT_MKL_ENABLED@": "0", + "@AT_FFTW_ENABLED@": "0", "@AT_NNPACK_ENABLED@": "0", "@CAFFE2_STATIC_LINK_CUDA_INT@": "0", "@USE_BLAS@": "1", diff --git a/CMakeLists.txt b/CMakeLists.txt index 62ea0a64d6c0..ba862b5a4d5f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -684,8 +684,8 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") int main() { float a[] = {1.0, 1.0}; float32x4x2_t v; - v.val[0] = vcombine_f32 (vcreate_f32 (__AARCH64_UINT64_C (0)), vcreate_f32 (__AARCH64_UINT64_C (0))); - v.val[1] = vcombine_f32 (vcreate_f32 (__AARCH64_UINT64_C (0)), vcreate_f32 (__AARCH64_UINT64_C (0))); + v.val[0] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); + v.val[1] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); vst1q_f32_x2(a, v); return 0; }" HAS_VST1) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6593e35e4cf9..51e9d1382808 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -674,7 +674,7 @@ ccache -M 25Gi ``` To check this is working, do two clean builds of pytorch in a row. The second -build should be substantially and noticeably faster than the first build. +build should be substantially and noticeably faster than the first build. If this doesn't seem to be the case, check that each of the symlinks above actually link to your installation of `ccache`. For example, if you followed the first option and installed `ccache` from source on a Linux machine, running `readlink -e $(which g++)` should return `~/ccache/bin/ccache`. #### Use a faster linker @@ -891,7 +891,7 @@ which is in PyTorch's `requirements.txt`. ## Pre-commit tidy/linting hook We use clang-tidy and flake8 (installed with flake8-bugbear, -flake8-comprehensions, flake8-mypy, and flake8-pyi) to perform additional +flake8-comprehensions, flake8-pyi, and others) to perform additional formatting and semantic checking of code. We provide a pre-commit git hook for performing these checks, before a commit is created: diff --git a/README.md b/README.md index 195dffc09058..d29eacc28664 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,7 @@ conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_ex On Linux ```bash # Add LAPACK support for the GPU if needed -conda install -c pytorch magma-cuda102 # or [ magma-cuda101 | magma-cuda100 | magma-cuda92 ] depending on your cuda version +conda install -c pytorch magma-cuda110 # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo ``` On MacOS diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp index e4bb4c083160..9cc71f117d93 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp @@ -90,13 +90,13 @@ class PytorchJni : public facebook::jni::HybridClass { #endif #ifdef TRACE_ENABLED - static bool onFunctionEnter( + static std::unique_ptr onFunctionEnter( const at::RecordFunction& fn) { Trace::beginSection(fn.name().str()); - return true; + return nullptr; } - static void onFunctionExit(const at::RecordFunction&) { + static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) { Trace::endSection(); } #endif diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle index 37bdb35e2f19..df7b758e3b31 100644 --- a/android/test_app/app/build.gradle +++ b/android/test_app/app/build.gradle @@ -60,20 +60,20 @@ android { //} flavorDimensions "model", "build", "activity" productFlavors { - mbq { + mnet { dimension "model" - applicationIdSuffix ".mbq" - buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"") - addManifestPlaceholders([APP_NAME: "MBQ"]) - buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"") + applicationIdSuffix ".mnet" + buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"") + addManifestPlaceholders([APP_NAME: "MNET"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"") } - mbvulkan { + mnetVulkan { dimension "model" - applicationIdSuffix ".mbvulkan" - buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"") + applicationIdSuffix ".mnet_vulkan" + buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"") buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true') - addManifestPlaceholders([APP_NAME: "MBQ"]) - buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"") + addManifestPlaceholders([APP_NAME: "MNET_VULKAN"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"") } resnet18 { dimension "model" diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index 16470f39ad54..0731e87f52a2 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -233,6 +233,32 @@ Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) { return self_physical.newLogicalFromPhysical(result); } +Tensor& fill_inplace_scalar_batching_rule(Tensor& self, Scalar value) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().fill_(value); + return self; +} + +Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) { + auto value_batched = isBatchedTensor(value); + + if (value_batched) { + auto physical_args = + BroadcastingVmapTransform::logicalToPhysical({self, value}); + physical_args[0].tensor().copy_(physical_args[1].tensor()); + } else { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().fill_(value); + } + return self; +} + +Tensor& zero_inplace_batching_rule(Tensor &self) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().zero_(); + return self; +} + Tensor squeeze_batching_rule(const Tensor& self) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto physical_sizes = self_physical.tensor().sizes(); @@ -941,8 +967,8 @@ Tensor new_empty_strided_batching_rule( size.size(), ") must match dimensionality of strides (", stride.size(), ")"); auto storage_size = native::storage_size_for(size, stride); - for (int64_t idx = 0; idx < physical_strides.size(); ++idx) { - physical_strides[idx] *= storage_size; + for (auto& physical_stride : physical_strides) { + physical_stride *= storage_size; } // physical_strides = [B1 * B2 * S, B2 * S, S] + strides @@ -971,6 +997,11 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("is_complex", native::is_complex); m.impl("conj", native::conj); + // inplace operations + m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule); + m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule); + m.impl("zero_", zero_inplace_batching_rule); + // view operations m.impl("as_strided", as_strided_batching_rule); m.impl("chunk", chunk_batching_rule); diff --git a/aten/src/ATen/CUDAGeneratorImpl.h b/aten/src/ATen/CUDAGeneratorImpl.h index ec83128c7013..9a9febd01f8e 100644 --- a/aten/src/ATen/CUDAGeneratorImpl.h +++ b/aten/src/ATen/CUDAGeneratorImpl.h @@ -131,8 +131,8 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl { uint64_t seed() override; void set_philox_offset_per_thread(uint64_t offset); uint64_t philox_offset_per_thread(); - void graph_prologue(int64_t* offset_extragraph); - uint64_t graph_epilogue(); + void capture_prologue(int64_t* offset_extragraph); + uint64_t capture_epilogue(); PhiloxCudaState philox_cuda_state(uint64_t increment); // Temporarily accommodates call sites that use philox_engine_inputs. @@ -147,6 +147,7 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl { uint64_t philox_offset_per_thread_ = 0; int64_t* offset_extragraph_; uint32_t offset_intragraph_ = 0; + bool graph_expects_this_gen_ = false; }; namespace cuda { diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index 58c06c63535d..38326491bed8 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -8,6 +8,7 @@ #define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@ #define AT_MKL_ENABLED() @AT_MKL_ENABLED@ +#define AT_FFTW_ENABLED() @AT_FFTW_ENABLED@ #define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@ #define CAFFE2_STATIC_LINK_CUDA() @CAFFE2_STATIC_LINK_CUDA_INT@ #define AT_BUILD_WITH_BLAS() @USE_BLAS@ diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 1977f945a0fb..e17322e1681d 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -232,7 +233,7 @@ bool Context::setFlushDenormal(bool on) { } Allocator* getCPUAllocator() { - return getTHDefaultAllocator(); + return c10::GetCPUAllocator(); } // override_allow_tf32_flag = true diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 9f0c51166172..41252609953f 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -2,17 +2,59 @@ #include #include +#include #include #include #include +#include #include +#include + +namespace at { +/** + * The method should_include_kernel_dtype() returns true/false + * based on whether the switching code for a specific dtype should be + * included based on build time constants generated from tracing model + * execution. This method will be implmeneted via code-generation and + * included in this file when code-gen is ready. + */ +inline constexpr bool should_include_kernel_dtype( + const char *kernel_tag_str, + at::ScalarType scalar_type +) { + return true; +} +} + +/** + * In the Facebook internal build (using BUCK), this macro is enabled by + * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer + * binary. + */ +#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \ + {RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::KERNEL_FUNCTION_DTYPE, \ + std::string(NAME) + "$" + toString(enum_type), \ + {});} +#else +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) +#endif -#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \ - case enum_type: { \ - using scalar_t = type; \ - return __VA_ARGS__(); \ +#define AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ + case enum_type: { \ + at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \ + [&] { \ + AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \ + } \ + ); \ + using HINT = type; \ + return __VA_ARGS__(); \ } +#define AT_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, scalar_t, __VA_ARGS__) + // Workaround for C10_UNUSED because CUDA 10.1 and below fails to handle unused // attribute in the type aliasing context. Keep name long and verbose to avoid // macro collisions. @@ -143,6 +185,21 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} // 4. Should complex be supported? The answer is almost always no, // unless you are working on "generic" code that should work on // all dtypes. +// +// Parameters: +// ----------- +// +// 1. The NAME argument is a "tag" that is used to trace and then +// conditionally compile fragments of the case statements such +// that the kernel functions are specialized only for the dtypes +// that are needed. The NAME parameter *must* be a build time +// cons char* (can't be std::string, etc...) +// +// Please ensure that the NAME is unique for every implementation +// or you run the risk of over-including code for the kernel +// functions. There is no risk of missing out on any code, so +// it's mostly a risk of a Type-2 error, and not a Type-1 error. +// // NB: the the_type variable is not used, but we have kept it for // backwards compatibility. It's probably not used by anyone though; @@ -154,26 +211,28 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ +#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ }() #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ @@ -181,10 +240,11 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -199,14 +259,17 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -220,13 +283,20 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ @@ -238,14 +308,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -259,19 +333,28 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} [&] { \ const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -285,31 +368,36 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ +#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() @@ -318,17 +406,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ }() #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ @@ -336,11 +425,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ @@ -351,6 +447,7 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_QINT_PRIVATE_CASE_TYPE( \ at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \ @@ -368,6 +465,7 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \ @@ -387,17 +485,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op*/ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, \ at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ @@ -406,154 +505,196 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_ALL_TYPES_AND3( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ +#define AT_DISPATCH_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE3, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE3, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() @@ -562,15 +703,10 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_index_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _it = ::detail::scalar_type(the_index_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _it) \ switch (_it) { \ - case at::ScalarType::Int: { \ - using index_t = int32_t; \ - return __VA_ARGS__(); \ - } \ - case at::ScalarType::Long: { \ - using index_t = int64_t; \ - return __VA_ARGS__(); \ - } \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Int, int32_t, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Long, int64_t, index_t, __VA_ARGS__)\ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \ } \ @@ -586,15 +722,16 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp index 5fcb5ede9cc5..a9198f7b2548 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp +++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp @@ -832,53 +832,6 @@ std::tuple _th_gels(const Tensor & self, const Tensor & A) { } return std::tuple(res1, res2); } -std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - case ScalarType::Float: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - default: - AT_ERROR("_th_eig_out not supported on CPUType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} -std::tuple _th_eig(const Tensor & self, bool eigenvectors) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto res1_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res1 = Tensor(c10::intrusive_ptr::reclaim(res1_)); - auto res2_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res2 = Tensor(c10::intrusive_ptr::reclaim(res2_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - default: - AT_ERROR("_th_eig not supported on CPUType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h index 1aca02539311..6e02db8075e2 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.h +++ b/aten/src/ATen/LegacyTHFunctionsCPU.h @@ -38,8 +38,6 @@ Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scala Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max); std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A); std::tuple _th_gels(const Tensor & self, const Tensor & A); -std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors); -std::tuple _th_eig(const Tensor & self, bool eigenvectors); Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper); Tensor _th_potri(const Tensor & self, bool upper); std::tuple _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self); diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 264271d35229..a9128e0e94ed 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -75,4 +75,16 @@ void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) { "Please clone() the tensor before performing the operation."); } +void assert_no_overlap(const Tensor& a, const Tensor& b) { + assert_no_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl()); +} + +void assert_no_overlap(TensorImpl* a, TensorImpl* b) { + const auto lap = get_overlap_status(a, b); + TORCH_CHECK(lap != MemOverlapStatus::PARTIAL && lap != MemOverlapStatus::FULL, + "unsupported operation: some elements of the input tensor and " + "the written-to tensor refer to a single memory location. " + "Please clone() the tensor before performing the operation."); +} + } diff --git a/aten/src/ATen/MemoryOverlap.h b/aten/src/ATen/MemoryOverlap.h index 67f63a64668c..5cd4eab2db9c 100644 --- a/aten/src/ATen/MemoryOverlap.h +++ b/aten/src/ATen/MemoryOverlap.h @@ -27,4 +27,7 @@ CAFFE2_API MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b); CAFFE2_API void assert_no_partial_overlap(const Tensor& a, const Tensor& b); void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b); +CAFFE2_API void assert_no_overlap(const Tensor& a, const Tensor& b); +CAFFE2_API void assert_no_overlap(TensorImpl* a, TensorImpl* b); + } diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp index 668838877123..5f8de486dc78 100644 --- a/aten/src/ATen/NamedTensorUtils.cpp +++ b/aten/src/ATen/NamedTensorUtils.cpp @@ -264,11 +264,11 @@ static std::vector compute_dot_product_outnames( } std::vector outnames(num_outnames, Dimname::wildcard()); int64_t index = 0; - for (int64_t j = 0; j < tensor_names.size(); ++j) { + for (size_t j = 0; j < tensor_names.size(); ++j) { if (j == tensor_dotted_dim) continue; outnames[index++] = tensor_names[j]; } - for (int64_t j = 0; j < other_names.size(); ++j) { + for (size_t j = 0; j < other_names.size(); ++j) { if (j == other_dotted_dim) continue; outnames[index++] = other_names[j]; } diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 45492d7b212e..8d7d4b2ce0f8 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -46,6 +46,8 @@ SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::Typ AT_ASSERT(values_.sizes() == IntArrayRef({0})); AT_ASSERT(values_.device() == indices_.device()); AT_ASSERT(values_.device() == device()); + + is_non_overlapping_and_dense_ = false; } IntArrayRef SparseTensorImpl::strides() const { diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index a2bdc24ff51c..162efd1c6c8a 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -4,6 +4,7 @@ #include #include #include +#include // TODO: try to remove this // There is some back story, see https://github.com/pytorch/pytorch/issues/48684 @@ -226,9 +227,9 @@ static inline Tensor applySelect( static inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) { // booleans add a dimension of size 1. true indexes this dimension as if 0:, false as empty. if (value) { - return at::native::zeros({1}, {}, self.options().dtype(kLong)); + return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.); } else { - return at::native::empty({0}, {}, self.options().dtype(kLong)); + return at::empty({0}, {}, self.options().dtype(kLong)); } } @@ -249,10 +250,6 @@ static inline Tensor boolToIndexingTensor(const Tensor& self, bool value, const } } -static inline Tensor scalarToTensorCPUOrCUDA(Scalar v, const TensorOptions& options) { - return at::native::scalar_tensor(v, options); -} - static inline Tensor scalarToTensorNonNativeDeviceType(Scalar v, const TensorOptions& options) { return at::scalar_tensor(v, options); } @@ -320,8 +317,11 @@ static inline int64_t count_specified_dimensions(const ArrayRef& in // The rest of the functions are in `at::indexing::impl` namespace, signifying // that they shouldn't be used from Python indexing implementation. static inline Tensor scalarToTensor(Scalar v, const TensorOptions& options, const at::Device& self_device) { - if (self_device == at::kCPU || self_device == at::kCUDA) { - return impl::scalarToTensorCPUOrCUDA(v, options); + if (self_device == at::kCPU && !v.isComplex() && + options.dtype_opt()->toScalarType() != ScalarType::ComplexDouble && + options.dtype_opt()->toScalarType() != ScalarType::ComplexFloat && + options.dtype_opt()->toScalarType() != ScalarType::ComplexHalf) { + return at::detail::scalar_tensor_static(v, options.dtype_opt()->toScalarType(), self_device); } else { return impl::scalarToTensorNonNativeDeviceType(v, options); } diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 43acc9a070d5..3f5f9280eb99 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -402,14 +402,14 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) { // TODO: reuse temporaries when possible (e.g. for inplace operations) if (common_device == kCPU) { // Casts to outputs by creating temporaries of the correct dtype (if needed) - if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_) { + // NB: we skip this on is_meta_, because the temporary allocation here is + // unnecessary if we aren't going to actually do the compute + if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) { TORCH_INTERNAL_ASSERT(op.tensor.defined()); + // Marker [Output original_tensor is set] op.original_tensor = op.tensor; // NB: do NOT use set_output here, as the temporary is NOT a true output; // op.tensor is the true output and it was pre-provided for us. - // TODO: When we extend this to work with meta tensors, we'll need to - // skip this temporary allocation in that case (because it's - // unnecessary) // TODO: The logic for cast_outputs will need to be handled by the // structured kernels implementation. What probably should happen // is that we pass in the inferred dtype into the out kernel, and @@ -488,10 +488,10 @@ void TensorIteratorBase::allocate_or_resize_outputs() { set_output(i, tensor_shape, tensor_stride, op.options(), names_); } op.current_dtype = op.target_dtype; - } else if (op.tensor.defined() && !names_.empty()) { - // Even if we don't resize, we may still propagate names, esp - // if we were doing an inplace operation - namedinference::propagate_names(op.tensor, names_); + } else if (op.tensor.defined()) { + // Even if we don't resize, we still need to tell set_output about + // the output, so that we properly set guard and propagate names + set_output(i, op.tensor.sizes(), {}, op.tensor.options(), names_); } } } @@ -765,6 +765,8 @@ void TensorIteratorBase::cast_outputs() { for (auto& op : operands_) { if (op.is_output && op.original_tensor.defined() && op.original_tensor.scalar_type() != op.current_dtype) { + // TODO: Now that set_output resizes both the original_tensor + // and tensor, this condition should no longer ever be true if (op.original_tensor.sizes() != op.tensor.sizes()){ op.original_tensor.resize_as_(op.tensor).as_strided_(op.tensor.sizes(), op.tensor.strides()); } @@ -808,18 +810,22 @@ void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indic } } -TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, - const Tensor& b) { - return TensorIteratorConfig() - .set_check_mem_overlap(true) - .add_output(out) - .add_input(a) - .add_input(b) - .allow_cpu_scalars(true) - .promote_inputs_to_common_dtype(true) - .cast_common_dtype_to_outputs(true) - .enforce_safe_casting_to_output(true) - .build(); +void TensorIteratorBase::build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b) { + build(TensorIteratorConfig() + .set_check_mem_overlap(true) + .add_output(out) + .add_input(a) + .add_input(b) + .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .enforce_safe_casting_to_output(true)); +} + +TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) { + TensorIterator iter; + iter.build_binary_op(out, a, b); + return iter; } // Helper to construct a binary op that promotes integer inputs to float. @@ -939,8 +945,15 @@ TensorIterator TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tenso } void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) { - for (int i = 0; i < config.tensors_.size(); i++) { - operands_.emplace_back(std::move(config.tensors_[i])); + for (auto& tensor: config.tensors_) { + // If *any* of the arguments is a meta tensor, the overall + // computation is a meta computation (don't do any work, + // just compute output information). This aligns with + // our multiple dispatch semantics. + if (tensor.is_meta()) { + is_meta_ = true; + } + operands_.emplace_back(std::move(tensor)); } num_outputs_ = config.num_outputs_; } @@ -988,6 +1001,10 @@ void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config if (!config.check_mem_overlap_) { return; } + if (is_meta_) { + // We don't have pointer addresses, cannot check for overlap! + return; + } for (int i = 0; i < num_outputs_; i++) { const auto& output = operands_[i].tensor; if (!output.defined()) continue; @@ -1265,9 +1282,11 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) { // allocate the output tensor if it's not provided allocate_or_resize_outputs(); // coalesce adjacent dimensions when possible - coalesce_dimensions(); + if (!is_meta_) coalesce_dimensions(); } + if (is_meta_) return; + for (auto& op : operands_) { TORCH_INTERNAL_ASSERT(op.tensor.defined()); op.data = op.tensor.data_ptr(); @@ -1281,14 +1300,92 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) { view_offsets_ = DimVector(ndim_offsets, 0); } +// This is the structured kernels implementation of set_output. It is +// NEVER actually called directly; instead, a subclass of TensorIteratorBase +// will override set_output to actually do the operation, and then call +// set_output on the TensorIteratorBase to setup TI's metadata. +// The precondition for this function is that maybe_get_output() now +// unconditionally returns a real Tensor (prior to output setting, +// this function may return an undefined tensor.) +void TensorIteratorBase::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) { + auto& op = operands_[output_idx]; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_); + const auto& t = maybe_get_output(output_idx); + TORCH_INTERNAL_ASSERT(t.defined()); + if (!op.tensor.defined()) { + op.tensor = t; + op.current_dtype = op.target_dtype; + } else if (op.will_resize) { + if (op.original_tensor.defined()) { + // OK, so this is pretty weird. To understand how we can end up in + // this situation, first look at Marker [Output original_tensor is set]. + // That is the sole site where original_tensor may be set on an + // output operand. Essentially, when we are given an explicit output + // tensor whose dtype doesn't match the computed common dtype from + // the input operands, we do a switcheroo: we replace the (incorrectly + // typed) output tensor with a correctly typed, *temporary* tensor, + // and remember the original tensor in original_tensor (which will + // then get written back to when we cast_outputs). + // + // Now, what if the given output tensor also happened to be zero + // size (meaning that we will_resize it)? Well, at the call site + // above, we don't necessarily(*) know what the correct shape should + // be, so we give the temporary tensor the same shape as the original. + // At the time of set_output is when we DO know what the correct size + // is, and the subclass's implementation of set_output in structured class + // responsible for resizing original_tensor. But we still have this + // incorrectly sized temporary output which the structured subclass + // knows nothing about, so we are obligated to also resize it here. + // + // This is a slight memory pessimization, because previously + // original_tensor only got resized at the end of the computation, rather + // than at the beginning (as happens here). However, the peak memory + // usage is the same, since you need to materialize both original tensor + // and temporary tensor to do the copy. + // + // (*) Actually, technically, we probably do know what the shape + // should be, since we do shape computation before dtype computation. + // So hypothetically we could figure out what the correct shape is + // at that point in time and directly allocate the temporary at + // the right size. + // + // But a better solution is to delay allocation of temporaries until + // after TensorIterator builder, waiting until we actually want + // to do the computation. That would also remove the necessity + // for the is_meta_ test. + TORCH_INTERNAL_ASSERT(op.original_tensor.is_same(t)); + TORCH_INTERNAL_ASSERT(!op.tensor.is_same(t)); + at::native::resize_output(op.tensor, sizes); + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + op.tensor.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + op.tensor.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); + } + } + } +} + +// This is the "traditional" implementation of set_output. On TensorIterator +// instances, it is invoked directly from various call sites in this file. No +// funny business. void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) { + // NB: intentionally no superclass call auto& op = operands_[output_idx]; TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_); if (!op.tensor.defined()) { if (strides.empty()) { - op.tensor = at::empty(sizes, options); + if (is_meta_) { + op.tensor = at::empty_meta(sizes, options); + } else { + op.tensor = at::empty(sizes, options); + } } else { - op.tensor = at::empty_strided(sizes, strides, options); + if (is_meta_) { + TORCH_INTERNAL_ASSERT(0, "meta strided not yet implemented"); + } else { + op.tensor = at::empty_strided(sizes, strides, options); + } } op.current_dtype = op.target_dtype; } else if (op.will_resize) { @@ -1306,6 +1403,14 @@ void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayR } } +// Not actually used by anything (TensorIterator subclass calls +// its own implementation of set_output which knows exactly where +// all the outputs are), but we have to provide all pure virtual methods +// for MetaBase +const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) { + return operands_[output_idx].tensor; +} + SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const { return SplitUntil32Bit(*this); } diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h index 11dbda5c7959..ba781d7501e6 100644 --- a/aten/src/ATen/TensorIterator.h +++ b/aten/src/ATen/TensorIterator.h @@ -297,6 +297,10 @@ struct CAFFE2_API TensorIteratorBase : public impl::MetaBase { return true; } + void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override; + + void build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b); + protected: // Mutable reference as it moves tensors out of TensorIteratorConfig void populate_operands(TensorIteratorConfig&); @@ -399,6 +403,9 @@ struct CAFFE2_API TensorIteratorBase : public impl::MetaBase { // From TensorIteratorConfig bool is_reduction_ = false; + + /// Set by populate_operands(), says if we're handling meta tensors + bool is_meta_ = false; }; struct CAFFE2_API TensorIterator final : public TensorIteratorBase { @@ -415,6 +422,7 @@ struct CAFFE2_API TensorIterator final : public TensorIteratorBase { static TensorIterator reduce_op(Tensor& out, const Tensor& a); static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a); + const Tensor& maybe_get_output(int64_t output_idx) override; void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override; }; diff --git a/aten/src/ATen/TensorMeta.cpp b/aten/src/ATen/TensorMeta.cpp index 30dca8ccaf2e..6f4d667d5653 100644 --- a/aten/src/ATen/TensorMeta.cpp +++ b/aten/src/ATen/TensorMeta.cpp @@ -1,21 +1,5 @@ #include -#include namespace at { -Tensor meta_tensor_from_meta(const TensorMeta& meta) { - // TODO: eliminate indirection - return at::empty_meta(meta.sizes, meta.options); -} - -Tensor tensor_from_meta(const TensorMeta& meta) { - // TODO: eliminate indirection - return at::empty(meta.sizes, meta.options); -} - -// Analogous to self.new_empty(sizes) -TensorMeta new_meta(const Tensor& self, IntArrayRef sizes) { - return TensorMeta(sizes, self.options()); -} - } // namespace at diff --git a/aten/src/ATen/TensorMeta.h b/aten/src/ATen/TensorMeta.h index baa6e6112b34..134bb373e3b2 100644 --- a/aten/src/ATen/TensorMeta.h +++ b/aten/src/ATen/TensorMeta.h @@ -10,28 +10,54 @@ class Tensor; namespace impl { -struct MetaBase { +// Use this to define the prototype for a meta function. There are two +// versions; one that takes one argument (just the operator name), or FUNC2 +// variant that takes two arguments (operator name and overload name). +// +// Example usage: +// +// TORCH_META_FUNC2(add, Tensor) ( +// const Tensor& self, const Tensor& other +// ) { +// ... compute sizes and options ... +// set_output(sizes, options); +// } +// +#define TORCH_META_FUNC(name) void name::meta +#define TORCH_META_FUNC2(name, overload) void name##_##overload::meta + +// Use this to define the prototype for an implementation. This takes only +// one argument, which is the name of the dispatch key entry you're +// implementing. +// +// Example usage: +// +// TORCH_IMPL_FUNC(add_cpu) ( +// Tensor& result, const Tensor& self, const Tensor& other +// ) { +// ... do the actual implementation ... +// } +// +#define TORCH_IMPL_FUNC(name) void structured_##name::impl + +// Base class for all structured kernel classes. The set_output virtual +// method is varied depending whether or not the operator is +// functional/out/inplace, and could also be specialized for CPU/CUDA/etc +// (although presently it isn't). +// +// A notable subclass of this interface is TensorIteratorBase. +struct CAFFE2_API MetaBase { virtual void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) = 0; + virtual const Tensor& maybe_get_output(int64_t output_idx) = 0; void set_output(IntArrayRef sizes, TensorOptions options) { set_output(0, sizes, {}, options, {}); } + // Returns a reference to an undefined tensor if there is no presupplied + // output + const Tensor& maybe_get_output() { return maybe_get_output(0); } virtual ~MetaBase() {} }; } // namespace impl -struct TensorMeta { - DimVector sizes; - // TODO: DimVector strides; - TensorOptions options; - - TensorMeta(IntArrayRef _sizes, TensorOptions _options) - : sizes(_sizes), options(_options) {} -}; - -CAFFE2_API Tensor meta_tensor_from_meta(const TensorMeta& meta); -CAFFE2_API Tensor tensor_from_meta(const TensorMeta& meta); -// Analogous to self.new_empty(sizes) -CAFFE2_API TensorMeta new_meta(const Tensor& self, IntArrayRef sizes); - } // namespace at diff --git a/aten/src/ATen/TensorNames.cpp b/aten/src/ATen/TensorNames.cpp index 844ff4ba2bad..a7dc0bd68036 100644 --- a/aten/src/ATen/TensorNames.cpp +++ b/aten/src/ATen/TensorNames.cpp @@ -61,10 +61,10 @@ TensorNames::TensorNames(ArrayRef names, int64_t start, int64_t end) { } TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) { - int64_t size_diff = std::labs(names_.size() - other.names_.size()); + size_t size_diff = std::labs(names_.size() - other.names_.size()); if (names_.size() > other.names_.size()) { - for (int64_t idx = size_diff; idx < names_.size(); ++idx) { + for (size_t idx = size_diff; idx < names_.size(); ++idx) { names_[idx] = names_[idx].unify(other.names_[idx - size_diff], op_name); } } else { diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index 6d74e2f47ce0..3c7b9b6ff5bc 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -19,6 +19,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode) grad_mode_enabled_ = GradMode::is_enabled(); } #endif + bumped_record_all_functions_ = at::checkRecordAllFunctions(); } /* static */ diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index f0cb85f0ff84..3c9b55b3d8d6 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -38,6 +38,9 @@ class TORCH_API ThreadLocalState { bool grad_mode_enabled_; #endif + // Whether pre-sampling RecordFunction optimization was enabled + bool bumped_record_all_functions_ = false; + friend class ThreadLocalStateGuard; }; @@ -45,7 +48,21 @@ class TORCH_API ThreadLocalState { class TORCH_API ThreadLocalStateGuard { public: explicit ThreadLocalStateGuard(const ThreadLocalState& state) - : prev_state_(ThreadLocalState()) { + : prev_state_(ThreadLocalState()), + bumped_record_all_functions_(state.bumped_record_all_functions_) { + // Special handling of RecordFunction pre-sampling optimization: + // pre-samping is enabled (bumped) when there're non-sampled + // (or high-frequency) global or TLS callbacks. + // + // ThreadLocalStateGuard simply resets RecordFunction's TLS and + // hence its thread local callbacks. + // + // Checking if the pre-sampling was enabled and preserving it in the + // async task by calling bumpRecordAllFunctions() and the corresponding + // releaseRecordAllFunctions() + if (bumped_record_all_functions_) { + at::bumpRecordAllFunctions(); + } // set the given state across the thread boundary ThreadLocalState::setThreadLocalState(state); } @@ -53,10 +70,15 @@ class TORCH_API ThreadLocalStateGuard { ~ThreadLocalStateGuard() { // restore previously set variables ThreadLocalState::setThreadLocalState(prev_state_); + if (bumped_record_all_functions_) { + at::releaseRecordAllFunctions(); + } } private: const ThreadLocalState prev_state_; + // Whether pre-sampling RecordFunction optimization was enabled + bool bumped_record_all_functions_ = false; }; template diff --git a/aten/src/ATen/Utils.cpp b/aten/src/ATen/Utils.cpp index a2e5a82c5d06..26fc7dabfd73 100644 --- a/aten/src/ATen/Utils.cpp +++ b/aten/src/ATen/Utils.cpp @@ -57,8 +57,12 @@ Tensor empty_cpu( tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); } - auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); + if (memory_format_opt.has_value()) { + // Restriding a just-created empty contiguous tensor does nothing. + if (*memory_format_opt != MemoryFormat::Contiguous) { + tensor.unsafeGetTensorImpl()->empty_tensor_restride(*memory_format_opt); + } + } return tensor; } diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 192e131897c8..6b9561767a5f 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -97,6 +97,14 @@ std::string used_cpu_capability() { ss << "CPU capability usage: "; auto capability = native::get_cpu_capability(); switch (capability) { +#ifdef HAVE_VSX_CPU_DEFINITION + case native::CPUCapability::DEFAULT: + ss << "DEFAULT"; + break; + case native::CPUCapability::VSX: + ss << "VSX"; + break; +#else case native::CPUCapability::DEFAULT: ss << "NO AVX"; break; @@ -106,6 +114,7 @@ std::string used_cpu_capability() { case native::CPUCapability::AVX2: ss << "AVX2"; break; +#endif default: break; } diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 3dfb4ee4f04b..3d040387d3bb 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -119,14 +119,6 @@ namespace impl { "You tried to register a kernel with an unsupported input type: List. Please use List, List or Tensor instead."); }; - template - struct assert_is_valid_input_type, AllowDeprecatedTypes> - : assert_is_valid_input_type { - static_assert(!std::is_same::value, - "You tried to register a kernel with an unsupported input type: std::vector. Please use List, List or Tensor instead."); - // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::vector. Please use List instead."); - }; - template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type { diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 5184e8c5f698..5e3e91afbb45 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -134,13 +134,11 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin OperatorName op_name = schema.operator_name(); auto op = findOrRegisterName_(op_name); - if (op.operatorIterator_->def_count == 0) { - // NB: registerSchema is not idempotent! Only do it once! - op.operatorIterator_->op.registerSchema(std::move(schema), std::move(debug)); - listeners_->callOnOperatorRegistered(op); - } else { - checkSchemaCompatibility(op, schema, debug); - } + TORCH_CHECK(op.operatorIterator_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.", + " Each overload's schema should only be registered with a single call to def().", + " Duplicate registration: ", debug, ". Original registration: ", op.operatorIterator_->op.debug()); + op.operatorIterator_->op.registerSchema(std::move(schema), std::move(debug)); + listeners_->callOnOperatorRegistered(op); // NB: do not increment the counts until AFTER error checking ++op.operatorIterator_->def_count; @@ -151,25 +149,6 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin }); } -void Dispatcher::checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug) { - TORCH_CHECK(op.schema() == schema, "Tried to register multiple operators with the same name and the same overload name but different schemas: ", schema, " (", debug, ") vs ", op.schema(), " (", op.debug(), ")"); - if (schema.isDefaultAliasAnalysisKind()) { - // [BACKWARDS COMPAT] If the *new* schema is the default alias analysis - // kind, for BC, we will accept it. If we don't accept it, most extensions - // that override existing operators will stop working (as they generally did - // not specify alias information). - } else if (op.schema().isDefaultAliasAnalysisKind()) { - // [BACKWARDS COMPAT] If you POST-FACTO specify a non-default alias analysis - // kind after we already have a schema for a function, bong it in for BC - // reasons. - op.operatorIterator_->op.updateSchemaAliasAnalysis(schema.aliasAnalysis()); - } else { - TORCH_CHECK(op.schema().aliasAnalysis() == schema.aliasAnalysis(), - "Tried to define the schema for ", toString(op.operator_name()), " with different alias analysis kinds: ", - toString(op.schema().aliasAnalysis()), " (", op.debug(), ") vs ", toString(schema.aliasAnalysis()), " (", debug, ")"); - } -} - void Dispatcher::deregisterDef_(const OperatorHandle& op, const OperatorName& op_name) { // we need a lock to avoid concurrent writes std::lock_guard lock(mutex_); diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 632739053c42..f83302e2d819 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -371,28 +371,39 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandleop.lookup(dispatchKey); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - // Check if we need to run callbacks registered with RecordFunction - // If true and callbacks need inputs, we box the arguments and pass - // them into the callbacks and also into the kernel call - - // Note: for perf reasons we wouldn't want to pass arguments into - // the function call or prematurely box them - at::RecordFunction guard(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(guard.isActive())) { - if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) { - int64_t seq_num = -1; - // Setting sequence number in the Autograd case to associate - // the forward range with the coresponding Autograd's node - if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { - seq_num = at::sequence_number::peek(); - } - if (guard.needsInputs()) { - torch::jit::Stack stack = impl::boxArgs(args...); - guard.before(op, stack, seq_num); - } else { - guard.before(op, seq_num); + // By default, when there're no high-frequency or non-sampled callbacks, + // RecordFunction is pre-sampled as a perf optimization; + // shouldRunRecordFunction checks whether RecordFunction should be executed, + // and sets pre_sampled boolean argument value to whether pre-sampling was used - + // this boolean is passed into RecordFunction to adjust the sampling rates of + // the callbacks + bool pre_sampled = false; + if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) { + // Check if we need to run callbacks registered with RecordFunction + // If true and callbacks need inputs, we box the arguments and pass + // them into the callbacks and also into the kernel call + + // Note: for perf reasons we wouldn't want to pass arguments into + // the function call or prematurely box them + at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) { + int64_t seq_num = -1; + // Setting sequence number in the Autograd case to associate + // the forward range with the coresponding Autograd's node + if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { + seq_num = at::sequence_number::peek(); + } + if (guard.needsInputs()) { + torch::jit::Stack stack = impl::boxArgs(args...); + guard.before(op, stack, seq_num); + } else { + guard.before(op, seq_num); + } } } + // keeping the guard alive while executing the kernel + return kernel.template call(op, std::forward(args)...); } #endif // PYTORCH_DISABLE_PER_OP_PROFILING return kernel.template call(op, std::forward(args)...); @@ -429,20 +440,26 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const const auto& kernel = entry.lookup(dispatchKey); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - // using already existing stack to record function execution in observers - at::RecordFunction guard(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(guard.isActive())) { - if (shouldRecord(dispatchKey) && entry.isObserved()) { - int64_t seq_num = -1; - if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { - seq_num = at::sequence_number::peek(); - } - if (guard.needsInputs()) { - guard.before(op, *stack, seq_num); - } else { - guard.before(op, seq_num); + bool pre_sampled = false; + if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) { + // using already existing stack to record function execution in observers + at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + if (shouldRecord(dispatchKey) && entry.isObserved()) { + int64_t seq_num = -1; + if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { + seq_num = at::sequence_number::peek(); + } + if (guard.needsInputs()) { + guard.before(op, *stack, seq_num); + } else { + guard.before(op, seq_num); + } } } + // keeping the guard alive while executing the kernel + kernel.callBoxed(op, stack); + return; } #endif // PYTORCH_DISABLE_PER_OP_PROFILING kernel.callBoxed(op, stack); diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 5a0efffea261..70247924c736 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -39,6 +39,7 @@ namespace c10 { _(prim, FunctionalGraph) \ _(prim, DifferentiableGraph) \ _(prim, TensorExprGroup) \ + _(prim, StaticSubgraph) \ _(prim, If) \ _(prim, Jump) /* debug */ \ _(prim, JumpNZ) /* debug */ \ @@ -139,6 +140,7 @@ namespace c10 { _(prim, HasAttr) \ _(prim, profile) \ _(prim, profile_optional) \ + _(prim, profile_ivalue) \ _(prim, AddStatValue) \ _(prim, TimePoint) \ _(prim, CallFunction) \ diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 6b8f4412cbf7..60382e37b6ff 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -22,7 +22,7 @@ namespace ivalue { // This is in ivalue.cpp because we need to access Type::annotation_str, which // is declared in jit_type.h -void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) { +void checkCustomClassType(const Type* expected_type, const Type* actual_type) { // NB: doing pointer comparison here // If in the future there ever arises a need to call operator== on custom class // Type's, this needs to be changed! diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 9ea18dc8482d..d2e72933b532 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -949,8 +949,8 @@ TORCH_API ska::flat_hash_map& getCustomClassTypeMap(); template -c10::ClassTypePtr getCustomClassType() { - auto tmap = c10::getCustomClassTypeMap(); +c10::ClassTypePtr getCustomClassTypeImpl() { + auto& tmap = c10::getCustomClassTypeMap(); auto res = tmap.find(std::type_index(typeid(T))); if (res == tmap.end()) { throw c10::Error("Can't find class id in custom class type map", ""); @@ -959,9 +959,13 @@ c10::ClassTypePtr getCustomClassType() { } template -inline bool isCustomClassRegistered() { - auto tmap = c10::getCustomClassTypeMap(); - return tmap.find(std::type_index(typeid(T))) != tmap.end(); +const c10::ClassTypePtr& getCustomClassType() { + // Classes are never unregistered from getCustomClassTypeMap and the + // hash lookup can be a hot path, so just cache. + // For the same reason, it's fine If this ends up getting duplicated across + // DSO boundaries for whatever reason. + static c10::ClassTypePtr cache = getCustomClassTypeImpl(); + return cache; } TORCH_API std::unordered_map>& diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 3068bda5f5a5..b3b53aed994c 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -172,7 +172,7 @@ inline at::Generator IValue::toGenerator() const& { namespace ivalue { void CAFFE2_API -checkCustomClassType(TypePtr expected_type, TypePtr actual_type); +checkCustomClassType(const Type* expected_type, const Type* actual_type); template using Shared = c10::intrusive_ptr; @@ -290,18 +290,22 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { /** * Wait on the future until it completes. */ - virtual void wait() { + void wait() { std::unique_lock lock(mutex_); while (!completed_) { finished_cv_.wait(lock); } + + if (!eptr_) { + postWaitHook(value_); + } } /** * Wait on the future until it completes and throw an * exception if an error exists. */ - virtual void waitAndThrow() { + void waitAndThrow() { std::unique_lock lock(mutex_); while (!completed_) { finished_cv_.wait(lock); @@ -310,12 +314,14 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { if (eptr_) { std::rethrow_exception(eptr_); } + + postWaitHook(value_); } /** * Explicitly mark the future as completed with the output value. */ - virtual void markCompleted(IValue value) { + void markCompleted(IValue value) { std::unique_lock lock(mutex_); TORCH_CHECK( !completed(), @@ -324,6 +330,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { completed_ = true; value_ = std::move(value); + postMarkCompletedHook(value_); + std::vector> cbs; cbs.swap(callbacks_); lock.unlock(); @@ -359,7 +367,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Get the result of the current future. - virtual IValue value() { + IValue value() { std::unique_lock lock(mutex_); AT_ASSERT(completed()); if (eptr_) { @@ -370,7 +378,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { // This accessor should only be used if we know that the future is // completed() with no error. - virtual const IValue& constValue() { + const IValue& constValue() { std::unique_lock lock(mutex_); AT_ASSERT(completed()); AT_ASSERT(!eptr_); @@ -383,8 +391,9 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { * If the future has already completed, * this function will execute the callback immediately. */ - virtual void addCallback(std::function callback) { + void addCallback(std::function callback) { std::unique_lock lock(mutex_); + callback = wrapCallback(std::move(callback)); if (completed()) { lock.unlock(); callback(); @@ -398,31 +407,47 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { * value of the callback. This is necessary when the callback provider needs * to know for sure when the callback has finished. */ - virtual c10::intrusive_ptr then( + c10::intrusive_ptr then( std::function callback, TypePtr type) { - auto fut = c10::make_intrusive(type); - // Cannot move capture std::function in lambda, because it cannot deduce - // the template type for std::function. Hence use std::bind to explicitly - // specify types. - addCallback(std::bind( - [fut](std::function cb) { + auto fut = createInstance(std::move(type)); + addCallback( + [fut, cb = std::move(callback)]() { try { fut->markCompleted(cb()); } catch (std::exception&) { fut->setError(std::current_exception()); } - }, - std::move(callback))); + }); return fut; } - // Since this file cannot import CUDA depedency, the type of the seocond arg - // in the callback is c10::Stream instead of at::cuda::CUDAStream, and - // CUDAStream is constructed on the fly. The default implementation - // is a no-op, since it does not deal with any CUDA streams. - virtual void setRecordStreamCallback( - std::function record_stream_cb) {} + // Some subclasses deal with CUDA tensors and must inform the CUDA caching + // allocator of which CUDA streams each DataPtr is used in. If the value held + // by the future is a Python object we need to acquire the GIL when extracting + // these DataPtrs. Since this file cannot depend on Python, we allow users to + // provide a "custom" extractor. Look for example at the PythonFutureWrapper. + using DataPtrExtractor = + std::function>( + const at::IValue&)>; + virtual void setDataPtrExtractor(DataPtrExtractor data_ptr_extractor) {} + + // Expose the default implementation so that external ones can defer to it. + static std::vector> + defaultDataPtrExtractor(const at::IValue& value) { + at::IValue::HashAliasedIValues sub_values; + // Prefer getSubValues() over visit() as the latter is a silent no-op for + // some unsupported types, whereas the former at least fails loudly. + value.getSubValues(sub_values); + + std::vector> res; + for (const at::IValue& sub_value : sub_values) { + if (sub_value.isTensor()) { + res.emplace_back(sub_value.toTensor().storage().data_ptr()); + } + } + return res; + }; // Tries to retrieve the error message from std::exception_ptr. std::string tryRetrieveErrorMessage() { @@ -432,11 +457,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Check if the current future has completed - virtual bool completed() const { + bool completed() const { return completed_; } - virtual bool hasValue() const { + bool hasValue() const { std::unique_lock lock(mutex_); return completed_ && !eptr_; } @@ -459,6 +484,43 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { return type_; } + protected: + // This hook is called by this class's then() method when it prepares the + // instance it returns to the caller. It should be overridden by subclasses so + // that they can produce an instace of their own type. + virtual c10::intrusive_ptr createInstance(at::TypePtr type) { + return c10::make_intrusive(type); + } + + // This hook will be called by this class (the superclass) when the future is + // marked completed _with a value_ (hence not in case of error). This is done + // right away, while the mutex is still held, before any callbacks are run. + // It allows subclasses to further update their state if they so need. For + // example the CUDAFuture subclass uses it to determine what devices the value + // resides on and record an event in those devices' current streams. + virtual void postMarkCompletedHook(const at::IValue& value) {} + + // This hook will be called by the addCallback() and the then() methods before + // storing the callback for later execution (or before running it inline if + // the future is already complete). Note that this method could thus be called + // while the future is _not_ yet complete. By default this method does nothing + // but subclasses can override this method to add functionality. For example + // the CUDAFuture subclass ensures the callback runs with CUDA streams which + // are synchronized with the events recorded in the I/O streams. + virtual std::function wrapCallback( + std::function callback) { + return callback; + } + + // This hook will be called by this class after a user thread has completed + // waiting on a successful future. It will thus not be called if the future + // completes with an error. It will also not be called if the user accesses + // the future's value without synchronization. Subclasses can override this + // to add some synchronization to the wait. For example, the CUDAFuture + // subclass ensures the user's current CUDA streams synchronize with the I/O + // events stored by the future. + virtual void postWaitHook(const at::IValue& value) {} + private: void setErrorInternal( std::exception_ptr eptr, @@ -467,6 +529,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { completed_ = true; eptr_ = std::move(eptr); + // Do not call postMarkCompletedHook() here as there isn't any value. + std::vector> cbs; cbs.swap(callbacks_); lock.unlock(); @@ -756,8 +820,8 @@ c10::intrusive_ptr IValue::toCustomClass() && { obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; @@ -774,8 +838,8 @@ c10::intrusive_ptr IValue::toCustomClass() const& { obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; @@ -1085,13 +1149,16 @@ template < typename T, std::enable_if_t::value, int>> IValue::IValue(c10::intrusive_ptr custom_class) { - if (!c10::isCustomClassRegistered>()) { - throw c10::Error( - "Trying to instantiate a class that isn't a registered custom class: " + - std::string(c10::util::get_fully_qualified_type_name()), - ""); - } - auto classType = c10::getCustomClassType>(); + TypePtr classType = []() { + try { + return c10::getCustomClassType>(); + } catch (const c10::Error&) { + throw c10::Error( + "Trying to instantiate a class that isn't a registered custom class: " + + std::string(c10::util::get_fully_qualified_type_name()), + ""); + } + }(); auto ivalue_obj = c10::ivalue::Object::create( c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1); ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 40c2ec7f443d..7fcd5c2d17e9 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -630,8 +630,7 @@ struct CAFFE2_API TensorType : public Type { const SymbolicShape& sizes, const VaryingShape& stride_, c10::optional requires_grad, - c10::optional undefined = false, - bool is_inferred = false); + c10::optional undefined = false); static TensorTypePtr create( c10::optional scalar_type, @@ -776,10 +775,13 @@ struct CAFFE2_API TensorType : public Type { static TensorTypePtr getInferred() { static auto valueInferred = TensorType::create( - /*scalar_type=*/{}, /*device=*/{}, - /*sizes=*/SymbolicShape(), - /*stride=*/VaryingShape{}, /*requires_grad=*/{}, - /*undefined=*/false, /*is_inferred=*/true); + /*scalar_type=*/{}, + /*device=*/{}, + /*sizes=*/SymbolicShape(), + /*stride=*/VaryingShape{}, + /*requires_grad=*/{}, + /*undefined=*/false); + valueInferred->is_inferred_ = true; return valueInferred; } @@ -808,6 +810,17 @@ struct CAFFE2_API TensorType : public Type { static const TypeKind Kind = TypeKind::TensorType; + static std::vector contiguousStridesOf(at::IntArrayRef sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) // zero-dim case + return strides; + strides.back() = 1; + for (size_t i = strides.size() - 1; i > 0; i--) { + strides[i - 1] = strides[i] * sizes[i]; + } + return strides; + } + private: TensorType( c10::optional scalar_type, @@ -822,17 +835,6 @@ struct CAFFE2_API TensorType : public Type { scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_)); } - static std::vector contiguousStridesOf(at::IntArrayRef sizes) { - std::vector strides(sizes.size()); - if (sizes.empty()) // zero-dim case - return strides; - strides.back() = 1; - for (size_t i = strides.size() - 1; i > 0; i--) { - strides[i - 1] = strides[i] * sizes[i]; - } - return strides; - } - static VaryingShape computeStrideProps( at::IntArrayRef sizes, at::IntArrayRef strides, @@ -1725,13 +1727,18 @@ namespace detail { template struct getTypePtr_ final { static TypePtr call() { - TORCH_CHECK( - isCustomClassRegistered(), - "Type ", - c10::util::get_fully_qualified_type_name(), - " could not be converted to any of the known types." - ); - auto res = getCustomClassType(); + TypePtr res = []() { + try { + return getCustomClassType(); + } catch(const c10::Error&) { + TORCH_CHECK( + false, + "Type ", + c10::util::get_fully_qualified_type_name(), + " could not be converted to any of the known types." + ); + } + }(); return std::dynamic_pointer_cast(std::move(res)); } }; diff --git a/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h b/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h index ea7a5bd0b54c..50c90937548f 100644 --- a/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h +++ b/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h @@ -207,12 +207,110 @@ constexpr auto with_explicit_optional_tensors(KernelFunc kernel_func) { return kernel_func; } +template constexpr bool is_out_argument_() { + return std::is_same::value; } +template using is_out_argument = guts::bool_constant()>; -template +template +struct with_out_arguments_reordered_impl final { +private: + // For an example op + // > aten::example(Tensor a, int64_t b, int64_t c, Tensor(a!) out_d, Tensor(b!) out_e) -> (Tensor(a!), Tensor(b!)) + // we get a KernelFunc + // > KernelFunc = std::tuple example(Tensor& out_d, Tensor& out_e, const Tensor& a, int64_t b, int64_t c) + // > NumOutParameters = 2 + // with the out arguments at the front, and reorder that into + // > std::tuple example(const Tensor& a, int64_t b, int64_t c, Tensor& out_d, Tensor& out_e) + // where the out arguments are in the back. + + using kernel_signature_traits = guts::infer_function_traits_t; + + // Assert that the KernelFunc is what we expect. The following block is + // not strictly necessary for the metaprogramming here, it's just a check. + static_assert( + guts::typelist::all< + is_out_argument, + guts::typelist::take_t< + typename kernel_signature_traits::parameter_types, + NumOutParameters + > + >::value, + "The kernel function has the wrong number of leading Tensor& arguments to match the out arguments in the JIT signature" + ); + + static constexpr size_t num_parameters = kernel_signature_traits::number_of_parameters; + static constexpr size_t num_nonout_parameters = num_parameters - NumOutParameters; + + // kernel_to_schema_permutation_indices contains a mapping from argument index in KernelFunc to the corresponding + // argument index in the schema. + // For the aten::example op, that'll be + // > kernel_to_schema_permutation_indices = [3, 4, 0, 1, 2] + // Interpreted as a mapping, this means + // - argument 0 in KernelFunc maps to argument 3 in the schema, + // - argument 1 in KernelFunc maps to argument 4 in the schema, + // - argument 2 in KernelFunc maps to argument 0 in the schema, + // - ... + // We can use this as a permutation function to reorder types or values correspondingly + using kernel_to_schema_permutation_indices = guts::concat_iseq_t< + guts::make_offset_index_sequence, + std::make_index_sequence + >; + + // For types, we need the inverse permutation because parameters (i.e. types) and arguments (i.e. values) + // need to be mapped in inverted directions. For types, we generate the schema order types from + // the KernelFunction types, but for arguments we get schema order arguments and need to generate + // the KernelFunction arguments. + // That's why in this reordering, we use NumOutParameters instead of the num_nonout_parameters we used above. + using target_signature_parameters = guts::typelist::concat_t< + guts::typelist::drop_t, + guts::typelist::take_t + >; + + template + struct wrapper_; + template + struct wrapper_, std::index_sequence> { + static Return call(Parameters... args) { + // call through to KernelFunc but reorder arguments as determined + // by the permutation we calculated above. + return (*KernelFunc::func_ptr())(std::get(std::tuple(std::forward(args)...))...); + } + }; + +public: + using wrapper = wrapper_; +}; + + +/** + * Take a kernel function that has a number of `Tensor`, `const Tensor&` or `Tensor&` arguments + * where all `Tensor&` arguments are at the beginning, and take NumOutParameters. + * Create a wrapper function that has `NumOutParameters` `Tensor&` arguments at the end + * and calls through the underlying kernel function by reordering them to the front. + */ +template 0), int> = 0> +constexpr auto with_out_arguments_reordered(KernelFunc kernel_func) { + // SFINAE case for kernels that have out tensor arguments. + // Wrap them and reorder the arguments. + using impl = with_out_arguments_reordered_impl; + return TORCH_FN((&impl::wrapper::call)); +} + +template = 0> +constexpr auto with_out_arguments_reordered(KernelFunc kernel_func) { + // SFINAE case for kernels that don't have out tensor arguments. + // Don't wrap them but just use the kernel directly. + return kernel_func; +} + +} + +template constexpr auto hacky_wrapper_for_legacy_signatures(FuncPtr kernel_func) { - auto with_tensoroptions_scattered = detail::with_scattered_tensor_options(kernel_func); - auto result = detail::with_explicit_optional_tensors(with_tensoroptions_scattered); + auto with_scattered_tensor_options = detail::with_scattered_tensor_options(kernel_func); + auto with_out_arguments_reordered = detail::with_out_arguments_reordered(with_scattered_tensor_options); + auto result = detail::with_explicit_optional_tensors(with_out_arguments_reordered); static_assert(std::is_same::value, "Generated signature doesn't match the expected one."); return result; }; diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 276e3a6838a3..429007e4242b 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -978,11 +978,9 @@ TensorTypePtr TensorType::create( const SymbolicShape& sizes, const VaryingShape& strides, c10::optional requires_grad, - c10::optional undefined, - bool is_inferred) { - auto pt = TensorTypePtr(new TensorType( + c10::optional undefined) { + auto pt = TensorTypePtr(new TensorType( scalar_type, device, sizes, strides, requires_grad, undefined)); - pt->is_inferred_ = is_inferred; return pt; } diff --git a/aten/src/ATen/cpu/vec256/missing_vst1_neon.h b/aten/src/ATen/cpu/vec256/missing_vst1_neon.h index dbb2ba479f85..dffd5dbb862e 100644 --- a/aten/src/ATen/cpu/vec256/missing_vst1_neon.h +++ b/aten/src/ATen/cpu/vec256/missing_vst1_neon.h @@ -4,6 +4,5 @@ __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_f32_x2 (float32_t * __a, float32x4x2_t val) { - asm ("st1 {%S0.4s - %T0.4s}, [%1]" :: "w" (val), "r"(__a) :); + asm ("st1 {%S1.4s - %T1.4s}, [%2]" : "=m" (*__a) : "w" (val), "r"(__a) : "memory"); } - diff --git a/aten/src/ATen/cpu/vec256/vec256.h b/aten/src/ATen/cpu/vec256/vec256.h index 96d17a9e1afa..ae40b9a5b4fd 100644 --- a/aten/src/ATen/cpu/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec256/vec256.h @@ -6,6 +6,7 @@ #include #include +#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) #include #include #include @@ -14,6 +15,9 @@ #include #include #include +#else +#include +#endif #include #include diff --git a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h index 58a677dc5de0..43389fe61583 100644 --- a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h @@ -25,7 +25,7 @@ static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) { static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) { __m256i lo = _mm256_castps_si256(a); __m256i hi = _mm256_castps_si256(b); - __m256i nan = _mm256_set1_epi32(0x7fc0); + __m256i nan = _mm256_set1_epi32(0xffff); __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q)); __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q)); __m256i ones = _mm256_set1_epi32(0x1); diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_common_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_common_vsx.h new file mode 100644 index 000000000000..516179932d34 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_common_vsx.h @@ -0,0 +1,216 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace at { +namespace vec256 { + +namespace { + +DEFINE_CLAMP_FUNCS(c10::quint8) +DEFINE_CLAMP_FUNCS(c10::qint8) +DEFINE_CLAMP_FUNCS(c10::qint32) +DEFINE_CLAMP_FUNCS(int16_t) +DEFINE_CLAMP_FUNCS(int32_t) +DEFINE_CLAMP_FUNCS(int64_t) +DEFINE_CLAMP_FUNCS(float) +DEFINE_CLAMP_FUNCS(double) + +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + vec_madd(a.vec0(), b.vec0(), c.vec0()), + vec_madd(a.vec1(), b.vec1(), c.vec1())}; +} + +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +template <> +Vec256 C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vec256& src) { + return Vec256{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +Vec256 C10_ALWAYS_INLINE +convert_to_int_of_same_size( + const Vec256& src) { + return Vec256{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vec256::size()); i += Vec256::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + vint32 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint32 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat32 c0 = vec_float(input_vec0); + vfloat32 c1 = vec_float(input_vec1); + vec_vsx_st(c0, offset0, dst_a); + vec_vsx_st(c1, offset16, dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vec256::size()); i += Vec256::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + vint64 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint64 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat64 c0 = vec_double(input_vec0); + vfloat64 c1 = vec_double(input_vec1); + vec_vsx_st(c0, offset0, reinterpret_cast(dst_a)); + vec_vsx_st(c1, offset16, reinterpret_cast(dst_a)); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +std::pair, Vec256> inline interleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0); + vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3); + vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0); + vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3); + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vec256{ab00, ab11}, Vec256{ab2_00, ab2_11}); +} + +template <> +std::pair, Vec256> inline deinterleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0); + vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0); + + vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3); + vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3); + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vec256{aa01, aa23}, Vec256{bb_01, bb_23}); +} + +template <> +std::pair, Vec256> inline interleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + + vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0()); + vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vec256{ab0011, ab2233}, Vec256{ab2_0011, ab2_2233}); +} + +template <> +std::pair, Vec256> inline deinterleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); + vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vec256{aa0123, aa0123_2}, Vec256{bb0123, bb0123_2}); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_double_vsx.h new file mode 100644 index 000000000000..f62ac36850be --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_double_vsx.h @@ -0,0 +1,597 @@ +#pragma once +#include +#include +#include +#include + +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { +using ComplexDbl = c10::complex; + +template <> +class Vec256 { + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexDbl; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + static constexpr int size() { + return 2; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {} + + Vec256(ComplexDbl val) { + double real_value = val.real(); + double imag_value = val.imag(); + _vec0 = vfloat64{real_value, imag_value}; + _vec1 = vfloat64{real_value, imag_value}; + } + Vec256(ComplexDbl val1, ComplexDbl val2) { + _vec0 = vfloat64{val1.real(), val1.imag()}; + _vec1 = vfloat64{val2.real(), val2.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {a._vec0, b._vec1}; + } + + template + static Vec256 C10_ALWAYS_INLINE + el_blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vec256 blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = + Vec256(vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0)); + return { + vec_sel(a._vec0, b._vec0, mask_complex._vecb0), + vec_sel(a._vec1, b._vec1, mask_complex._vecb1)}; + } + + static Vec256 C10_ALWAYS_INLINE elwise_blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vec256 arange( + ComplexDbl base = 0., + step_t step = static_cast(1)) { + return Vec256(base, base + step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexDbl& operator[](int idx) const = delete; + ComplexDbl& operator[](int idx) = delete; + + Vec256 map(ComplexDbl (*f)(ComplexDbl)) const { + __at_align32__ ComplexDbl tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vec256 map(ComplexDbl (*f)(const ComplexDbl&)) const { + __at_align32__ ComplexDbl tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vec256 el_swapped() const { + vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2); + vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2); + return {v0, v1}; + } + + Vec256 el_madd( + const Vec256& multiplier, + const Vec256& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + Vec256 el_mergeo() const { + vfloat64 v0 = vec_splat(_vec0, 1); + vfloat64 v1 = vec_splat(_vec1, 1); + return {v0, v1}; + } + + Vec256 el_mergee() const { + vfloat64 v0 = vec_splat(_vec0, 0); + vfloat64 v1 = vec_splat(_vec1, 0); + return {v0, v1}; + } + + static Vec256 el_mergee( + Vec256& first, + Vec256& second) { + // as mergee phased in , we can use vec_perm with mask + return { + vec_mergeh(first._vec0, second._vec0), + vec_mergeh(first._vec1, second._vec1)}; + } + + Vec256 abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a; + } + + Vec256 abs_() const { + auto ret = abs_2_(); + return ret.elwise_sqrt(); + } + + Vec256 abs() const { + return abs_() & vd_real_mask; + } + + Vec256 angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_pd(values, 0x05); // b a + // return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + auto ret = el_swapped(); + for (int i = 0; i < 2; i++) { + ret._vec0[i] = std::atan2(_vec0[i], ret._vec0[i]); + ret._vec1[i] = std::atan2(_vec1[i], ret._vec0[i]); + } + return ret; + } + + Vec256 angle() const { + auto a = angle_().el_swapped(); + return a & vd_real_mask; + } + + Vec256 real_() const { + return *this & vd_real_mask; + } + Vec256 real() const { + return *this & vd_real_mask; + } + Vec256 imag_() const { + return *this & vd_imag_mask; + } + Vec256 imag() const { + return imag_().el_swapped(); + } + + Vec256 conj_() const { + return *this ^ vd_isign_mask; + } + Vec256 conj() const { + return *this ^ vd_isign_mask; + } + + Vec256 log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vec256 log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(vd_log2e_inv); + } + Vec256 log10() const { + auto ret = log(); + return ret.elwise_mult(vd_log10e_inv); + } + + Vec256 asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub(val_2, val_2_swapped); + re = Vec256(vd_one) - re; + auto root = el_blend<0x0A>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); + } + + Vec256 acos() const { + // acos(x) = pi/2 - asin(x) + return Vec256(vd_pi_2) - asin(); + } + + Vec256 atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vec256(vd_imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * vd_imag_half; // i/2*ln() + } + + Vec256 sin() const { + return map(std::sin); + } + Vec256 sinh() const { + return map(std::sinh); + } + Vec256 cos() const { + return map(std::cos); + } + Vec256 cosh() const { + return map(std::cosh); + } + + Vec256 tan() const { + return map(std::tan); + } + Vec256 tanh() const { + return map(std::tanh); + } + Vec256 ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 neg() const { + auto z = Vec256(vd_zero); + return z - *this; + } + Vec256 round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vec256 trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << ","; + std::cout << _vec1[0] << "," << _vec1[1] << std::endl; + } + + Vec256 sqrt() const { + return map(std::sqrt); + } + + Vec256 reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ vd_isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vec256 rsqrt() const { + return sqrt().reciprocal(); + } + + static Vec256 horizontal_add( + Vec256& first, + Vec256& second) { + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first + first_perm; // 2add + auto second_ret = second + second_perm; // 2 add + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + static Vec256 horizontal_sub( + Vec256& first, + Vec256& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vec256 inline operator*(const Vec256& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ vd_rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); +#else + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ vd_isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub(ac_bd, ad_bc); +#endif + return ret; + } + + Vec256 inline operator/(const Vec256& b) const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() +#if 1 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + auto abs_b = b.abs_2_(); + vi = vi ^ vd_isign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); + ret = ret.elwise_div(abs_b); +#else + // Vec256 x86 simulation + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ vd_rsign_mask; + auto ad_bc = elwise_mult(d_c); + auto abs_b = b.abs_2_(); + auto re_im = horizontal_add(ac_bd, ad_bc); + auto ret = re_im.elwise_div(abs_b); +#endif + return ret; + } + + Vec256 exp() const { + return map(std::exp); + } + + Vec256 pow(const Vec256& exp) const { + __at_align32__ ComplexDbl x_tmp[size()]; + __at_align32__ ComplexDbl y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vec256 sgn() const { + return map(at::native::sgn_impl); + } + + Vec256 hypot(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 nextafter(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igamma(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igammac(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 log1p() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 atan2(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 erf() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 erfc() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 expm1() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator<(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 operator<=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 operator>(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 operator>=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 eq(const Vec256& other) const { + auto ret = (*this == other); + return ret & vd_one; + } + Vec256 ne(const Vec256& other) const { + auto ret = (*this != other); + return ret & vd_one; + } + + Vec256 lt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 le(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 gt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 ge(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor) + // elelemtwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vec256::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vec256::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + + +} // namespace +} // namespace vec256 +} // namespace at + diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_float_vsx.h new file mode 100644 index 000000000000..cb9b4c90fbe0 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_float_vsx.h @@ -0,0 +1,670 @@ + +#pragma once +#include +#include +#include +#include + +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { +using ComplexFlt = c10::complex; + +template <> +class Vec256 { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexFlt; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + + static constexpr int size() { + return 4; + } + Vec256() {} + + C10_ALWAYS_INLINE Vec256(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + + Vec256(ComplexFlt val) { + float real_value = val.real(); + float imag_value = val.imag(); + _vec0 = vfloat32{real_value, imag_value, real_value, imag_value}; + _vec1 = vfloat32{real_value, imag_value, real_value, imag_value}; + } + + Vec256(ComplexFlt val1, ComplexFlt val2, ComplexFlt val3, ComplexFlt val4) { + _vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()}; + _vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + const vbool32 mask_2nd = VsxComplexMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static Vec256 C10_ALWAYS_INLINE + el_blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vec256 blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = Vec256( + vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1)); + // mask_complex.dump(); + return { + vec_sel(a._vec0, b._vec0, mask_complex._vec0), + vec_sel(a._vec1, b._vec1, mask_complex._vec1), + }; + } + + static Vec256 elwise_blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + return { + vec_sel(a._vec0, b._vec0, mask._vec0), + vec_sel(a._vec1, b._vec1, mask._vec1), + }; + } + + template + static Vec256 arange( + ComplexFlt base = 0., + step_t step = static_cast(1)) { + return Vec256( + base, + base + step, + base + ComplexFlt(2) * step, + base + ComplexFlt(3) * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexFlt& operator[](int idx) const = delete; + ComplexFlt& operator[](int idx) = delete; + + Vec256 map(ComplexFlt (*f)(ComplexFlt)) const { + __at_align32__ ComplexFlt tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vec256 map(ComplexFlt (*f)(const ComplexFlt&)) const { + __at_align32__ ComplexFlt tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + static Vec256 horizontal_add_permD8( + Vec256& first, + Vec256& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first + first_perm; // 2add + auto second_ret = second + second_perm; // 2 add + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + static Vec256 horizontal_sub_permD8( + Vec256& first, + Vec256& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vec256 abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a.el_mergee(); + } + + Vec256 abs_() const { + auto ret = abs_2_(); + return ret.elwise_sqrt(); + } + + Vec256 abs() const { + return abs_() & real_mask; + } + + Vec256 real_() const { + return *this & real_mask; + } + Vec256 real() const { + return *this & real_mask; + } + Vec256 imag_() const { + return *this & imag_mask; + } + Vec256 imag() const { + // we can use swap_mask or sldwi + auto ret = imag_(); + return { + vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)}; + } + + Vec256 conj_() const { + return *this ^ isign_mask; + } + Vec256 conj() const { + return *this ^ isign_mask; + } + + Vec256 log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vec256 log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(log2e_inv); + } + Vec256 log10() const { + auto ret = log(); + return ret.elwise_mult(log10e_inv); + } + + Vec256 el_swapped() const { + vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask); + vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + Vec256 el_mergee() const { + // as mergee phased in , we can use vec_perm with mask + return {vec_mergee(_vec0, _vec0), vec_mergee(_vec1, _vec1)}; + } + + Vec256 el_mergeo() const { + // as mergeo phased in , we can use vec_perm with mask + return {vec_mergeo(_vec0, _vec0), vec_mergeo(_vec1, _vec1)}; + } + + Vec256 el_madd( + const Vec256& multiplier, + const Vec256& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + static Vec256 el_mergee( + Vec256& first, + Vec256& second) { + // as mergee phased in , we can use vec_perm with mask + return { + vec_mergee(first._vec0, second._vec0), + vec_mergee(first._vec1, second._vec1)}; + } + + Vec256 angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_ps(values, 0xB1); // b a + // return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + auto ret = el_swapped(); + for (int i = 0; i < 4; i++) { + ret._vec0[i] = std::atan2(_vec0[i], ret._vec0[i]); + ret._vec1[i] = std::atan2(_vec1[i], ret._vec0[i]); + } + return ret; + } + + Vec256 angle() const { + auto a = angle_().el_swapped(); + return a & real_mask; + } + + Vec256 sin() const { + return map(std::sin); + } + Vec256 sinh() const { + return map(std::sinh); + } + Vec256 cos() const { + return map(std::cos); + } + Vec256 cosh() const { + return map(std::cosh); + } + Vec256 ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 neg() const { + auto z = Vec256(zero); + return z - *this; + } + Vec256 round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vec256 tan() const { + return map(std::tan); + } + Vec256 tanh() const { + return map(std::tanh); + } + Vec256 trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << "," << _vec0[2] << "," + << _vec0[3] << ","; + std::cout << _vec1[0] << "," << _vec1[1] << "," << _vec1[2] << "," + << _vec1[3] << std::endl; + } + + Vec256 sqrt() const { + return map(std::sqrt); + } + + Vec256 reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vec256 rsqrt() const { + return sqrt().reciprocal(); + } + + Vec256 pow(const Vec256& exp) const { + __at_align32__ ComplexFlt x_tmp[size()]; + __at_align32__ ComplexFlt y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vec256 atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vec256(imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * imag_half; // i/2*ln() + } + + Vec256 acos() const { + // acos(x) = pi/2 - asin(x) + return Vec256(pi_2) - asin(); + } + + Vec256 inline operator*(const Vec256& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); + return ret; + +#else + + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub_permD8(ac_bd, ad_bc); + return ret; +#endif + } + + Vec256 inline operator/(const Vec256& b) const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() +#if 1 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + auto abs_b = b.abs_2_(); + vi = vi ^ isign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); + ret = ret.elwise_div(abs_b); +#else + // Vec256 x86 simulation + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ rsign_mask; + auto ad_bc = elwise_mult(d_c); + auto abs_b = b.abs_2_(); + auto re_im = horizontal_add_permD8(ac_bd, ad_bc); + auto ret = re_im.elwise_div(abs_b); +#endif + return ret; + } + + Vec256 asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + +#if 1 + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub_permD8(val_2, val_2_swapped); + re = Vec256(one) - re; + auto root = el_blend<0xAA>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); +#else + return map(std::asin); +#endif + } + + Vec256 exp() const { + return map(std::exp); + } + + Vec256 eq(const Vec256& other) const { + auto ret = (*this == other); + return ret & one; + } + Vec256 ne(const Vec256& other) const { + auto ret = (*this != other); + return ret & one; + } + + Vec256 sgn() const { + return map(at::native::sgn_impl); + } + + Vec256 hypot(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 nextafter(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igamma(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igammac(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 atan2(const Vec256& b) const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + Vec256 erf() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + Vec256 erfc() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + + Vec256 log1p() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + + Vec256 expm1() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + + Vec256 operator<(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator<=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator>(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator>=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 lt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 le(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 gt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 ge(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor) + // elelemtwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vec256::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vec256::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_double_vsx.h new file mode 100644 index 000000000000..f34bdc7bbcb3 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_double_vsx.h @@ -0,0 +1,392 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace vec256 { + +namespace { + + +template <> +class Vec256 { + private: + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = double; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + static constexpr int size() { + return 4; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(double scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + double scalar1, + double scalar2, + double scalar3, + double scalar4) + : _vec0{vfloat64{scalar1, scalar2}}, _vec1{vfloat64{scalar3, scalar4}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + int zero_mask() const { + auto cmp = (*this == vd_zero); + return (cmp._vecb0[0] & 1) | (cmp._vecb0[1] & 2) | (cmp._vecb1[0] & 4) | + (cmp._vecb1[1] & 8); + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return { b._vec0, a._vec1 }; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return { a._vec0, b._vec1 }; + } + + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1 }; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1 }; + } + + + template + static std::enable_if_t> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return { a._vec0, + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) }; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return { b._vec0, + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) }; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) }; + } + + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + static Vec256 arange(double base = 0., double step = 1.) { + return Vec256(base, base + step, base + 2 * step, base + 3 * step); + } + + static Vec256 C10_ALWAYS_INLINE + set(const Vec256& a, const Vec256& b, size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << "," << _vec1[0] << "," << _vec1[1] << std::endl; + } + Vec256 map(double (*f)(double)) const { + Vec256 ret; + for (int i = 0; i < size()/2; i++) { + ret._vec0[i] = f(_vec0[i]); + } + for (int i = 0; i < size()/2; i++) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vec256 mapbi(double (*f)(double, double), const Vec256& other) + const { + Vec256 ret; + for (int i = 0; i < size()/2; i++) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (int i = 0; i < size()/2; i++) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE acos() const { + return {Sleef_acosd2_u10vsx(_vec0), Sleef_acosd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE asin() const { + return {Sleef_asind2_u10vsx(_vec0), Sleef_asind2_u10vsx(_vec1)}; + } + Vec256 atan() const { + return {Sleef_atand2_u10vsx(_vec0), Sleef_atand2_u10vsx(_vec1)}; + } + Vec256 atan2(const Vec256& b) const { + return {Sleef_atan2d2_u10vsx(_vec0, b._vec0), Sleef_atan2d2_u10vsx(_vec1, b._vec1)}; + } + Vec256 erf() const { + return {Sleef_erfd2_u10vsx(_vec0), Sleef_erfd2_u10vsx(_vec1)}; + } + Vec256 erfc() const { + return {Sleef_erfcd2_u15vsx(_vec0), Sleef_erfcd2_u15vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE exp() const { + return {Sleef_expd2_u10vsx(_vec0), Sleef_expd2_u10vsx(_vec1)}; + } + Vec256 expm1() const { + return {Sleef_expm1d2_u10vsx(_vec0), Sleef_expm1d2_u10vsx(_vec1)}; + } + + Vec256 lgamma() const __ubsan_ignore_undefined__ { + return {Sleef_lgammad2_u10vsx(_vec0), Sleef_lgammad2_u10vsx(_vec1)}; + } + + Vec256 erfinv() const { + return map(calc_erfinv); + } + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE log() const { + return {Sleef_logd2_u10vsx(_vec0), Sleef_logd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE log10() const { + return {Sleef_log10d2_u10vsx(_vec0), Sleef_log10d2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE log1p() const { + return {Sleef_log1pd2_u10vsx(_vec0), Sleef_log1pd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE log2() const { + return {Sleef_log2d2_u10vsx(_vec0), Sleef_log2d2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE cos() const { + return {Sleef_cosd2_u10vsx(_vec0), Sleef_cosd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshd2_u10vsx(_vec0), Sleef_coshd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE sin() const { + return {Sleef_sind2_u10vsx(_vec0), Sleef_sind2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhd2_u10vsx(_vec0), Sleef_sinhd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE tan() const { + return {Sleef_tand2_u10vsx(_vec0), Sleef_tand2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhd2_u10vsx(_vec0), Sleef_tanhd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vec256 C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE reciprocal() const { + return { + vec_div(vd_one, _vec0), // vec_re(_vec0) is estimated one. + vec_div(vd_one, _vec1)}; + } + Vec256 C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vec256 C10_ALWAYS_INLINE pow(const Vec256& b) const { + return {Sleef_powd2_u10vsx(_vec0, b._vec0), Sleef_powd2_u10vsx(_vec1, b._vec1)}; + } + Vec256 C10_ALWAYS_INLINE fmod(const Vec256& b) const { + return {Sleef_fmodd2_vsx(_vec0, b._vec0),Sleef_fmodd2_vsx(_vec1, b._vec1)}; + } + + Vec256 hypot(const Vec256& b) const { + return {Sleef_hypotd2_u05vsx(_vec0, b._vec0), Sleef_hypotd2_u05vsx(_vec1, b._vec1)}; + } + + Vec256 nextafter(const Vec256& b) const { + return {Sleef_nextafterd2_vsx(_vec0, b._vec0), Sleef_nextafterd2_vsx(_vec1, b._vec1)}; + } + + Vec256 igamma(const Vec256& x) const { + return mapbi(calc_igamma, x); + } + + Vec256 igammac(const Vec256& x) const { + return mapbi(calc_igammac, x); + } + + + Vec256 i0() const { + return map(calc_i0); + } + + DEFINE_MEMBER_OP(operator==, double, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, double, vec_cmpne) + DEFINE_MEMBER_OP(operator<, double, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, double, vec_cmple) + DEFINE_MEMBER_OP(operator>, double, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, double, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, double, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, double, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, double, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, double, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, double, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, double, vec_cmpge) + DEFINE_MEMBER_OP(operator+, double, vec_add) + DEFINE_MEMBER_OP(operator-, double, vec_sub) + DEFINE_MEMBER_OP(operator*, double, vec_mul) + DEFINE_MEMBER_OP(operator/, double, vec_div) + DEFINE_MEMBER_OP(maximum, double, vec_max) + DEFINE_MEMBER_OP(minimum, double, vec_min) + DEFINE_MEMBER_OP(operator&, double, vec_and) + DEFINE_MEMBER_OP(operator|, double, vec_or) + DEFINE_MEMBER_OP(operator^, double, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, double, vec_madd) +}; +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_float_vsx.h new file mode 100644 index 000000000000..2a1a87aa72c8 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_float_vsx.h @@ -0,0 +1,676 @@ +#pragma once + +#include +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] + +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = float; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + + static constexpr int size() { + return 8; + } + Vec256() {} + + C10_ALWAYS_INLINE Vec256(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(float scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + float scalar1, + float scalar2, + float scalar3, + float scalar4, + float scalar5, + float scalar6, + float scalar7, + float scalar8) + : _vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + static Vec256 arange(float base = 0.f, float step = 1.f) { + return Vec256( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + + Vec256 map(float (*f)(float)) const { + Vec256 ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vec256 mapbi(float (*f)(float, float), const Vec256& other) + const { + Vec256 ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + + Vec256 _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vec256 _isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } + + Vec256 _isinf() const { + auto x = *this; + return (x == v_inf) | (x == v_minus_inf); + } + + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + //__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + auto cmp = (*this == zero); + // return _mm256_movemask_ps(cmp); + // possible simulation //mask= lvsl ( 0 ) vbpermq( vec, mask <<5) + vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits); + vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits); + return (result0[1] >> 12 | (result1[1] >> 8)); + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE acos() const { + return {Sleef_acosf4_u10vsx(_vec0), Sleef_acosf4_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE asin() const { + return {Sleef_asinf4_u10vsx(_vec0), Sleef_asinf4_u10vsx(_vec1)}; + } + Vec256 atan() const { + return {Sleef_atanf4_u10vsx(_vec0), Sleef_atanf4_u10vsx(_vec1)}; + } + Vec256 atan2(const Vec256& b) const { + return {Sleef_atan2f4_u10vsx(_vec0, b._vec0), Sleef_atan2f4_u10vsx(_vec1, b._vec1)}; + } + + Vec256 lgamma() const { + return {Sleef_lgammaf4_u10vsx(_vec0), Sleef_lgammaf4_u10vsx(_vec1)}; + } + Vec256 erf() const { + return {Sleef_erff4_u10vsx(_vec0), Sleef_erff4_u10vsx(_vec1)}; + } + + Vec256 erfc() const { + return {Sleef_erfcf4_u15vsx(_vec0), Sleef_erfcf4_u15vsx(_vec1)}; + } + + Vec256 erfinv() const { + return map(calc_erfinv); + } + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE exp() const { + // implementation logic from avx_mathfun with some modifications from sleef + // Express e**x = e**g 2**n + /// = e**g e**( n loge(2) ) + /// = e**( g + n loge(2) ) + // + auto tmp_x = *this; + auto fx = (tmp_x * log2e_inv).round(); + + auto x = fx.madd(negln2f_hi, tmp_x); + x = fx.madd(negln2f_lo, x); + auto z = x * x; + auto y = x.madd(exp_p0, exp_p1); + y = y.madd(x, exp_p2); + y = y.madd(x, exp_p3); + y = y.madd(x, exp_p4); + y = y.madd(x, exp_p5); + y = y.madd(z, x) + one; + + // vm_pow2n 2^n + vint32 imm0 = vec_signed(fx._vec0); + vint32 imm1 = vec_signed(fx._vec1); + // this pow2n logic is from Sleef code + vint32 imm00 = imm0 >> 1; //>>1 + vint32 imm01 = imm1 >> 1; + vint32 imm10 = imm0 - imm00; + vint32 imm11 = imm1 - imm01; + imm00 = (imm00 + v0x7f) << vu_23; + imm01 = (imm01 + v0x7f) << vu_23; + imm10 = (imm10 + v0x7f) << vu_23; + imm11 = (imm11 + v0x7f) << vu_23; + // treat imm as float vector without conversion + + y._vec0 = (y._vec0 * (vfloat32)imm00) * (vfloat32)imm10; + y._vec1 = (y._vec1 * (vfloat32)imm01) * (vfloat32)imm11; + // boundary check + auto tmp = blendv(y, v_inf, (Vec256(exp_hi) <= tmp_x)); + y = blendv(tmp, zero, (tmp_x < Vec256(exp_lo))); + + return y; + } + Vec256 expm1() const { + return exp() - one; + } + + Vec256 C10_ALWAYS_INLINE log() const { + auto temp = *this; + auto invalid_mask = temp < zero; + // cut off denormalized stuff + auto x = temp.maximum(min_norm_pos); + vint32 imm0 = vec_sr(vint32(x._vec0), vu_23); + vint32 imm1 = vec_sr(vint32(x._vec1), vu_23); + // keep only the fractional part + x = x & inv_mant_mask; + x = x | half; + imm0 = imm0 - v0x7f; + imm1 = imm1 - v0x7f; + Vec256 ex; + ex._vec0 = vec_float(imm0); + ex._vec1 = vec_float(imm1); + ex = ex + one; + auto mask = x < cephes_SQRTHF; + auto t = x & mask; + x = x - one; + ex = ex - (mask & one); + x = x + t; + auto z = x * x; + auto y = x.madd(log_p0, log_p1); + y = y.madd(x, log_p2); + y = y.madd(x, log_p3); + y = y.madd(x, log_p4); + y = y.madd(x, log_p5); + y = y.madd(x, log_p6); + y = y.madd(x, log_p7); + y = y.madd(x, log_p8); + y = y * x * z; + y = ex.madd(log_q1, y); + y = y - z * half; + x = x + y; + x = ex.madd(log_q2, x); + // negative arg will be NAN + x = blendv(x, v_nan, invalid_mask); + // zero is -inf + x = blendv(x, min_inf, (temp == zero)); + return x; + } + Vec256 C10_ALWAYS_INLINE log10() const { + return log() * log10e_inv; + } + Vec256 C10_ALWAYS_INLINE log1p() const { + return ((*this) + one).log(); + } + Vec256 C10_ALWAYS_INLINE log2() const { + return log() * log2e_inv; + } + Vec256 C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE cos() const { + // take the absolute value + auto x = abs(); + // extract the sign bit (upper one) + auto sign_bit = (*this) & sign_mask; + // scale by 4/Pi + auto y = x * _4div_pi; + // store the integer part of y in mm0 + // j=(j+1) & (~1) (see the cephes sources) + vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; + vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; + y._vec0 = vec_float(imm0); + y._vec1 = vec_float(imm1); + + imm0 = imm0 - vi_2; + imm1 = imm1 - vi_2; + Vec256 poly_mask; + // get the swap sign flag + vint32 tmp0 = vec_and(vec_nand(imm0, imm0), vi_4); + vint32 tmp1 = vec_and(vec_nand(imm1, imm1), vi_4); + sign_bit._vecb0 = (vbool32)vec_sl(tmp0, vu_29); + sign_bit._vecb1 = (vbool32)vec_sl(tmp1, vu_29); + // get the polynom selection mask + // there is one polynom for 0 <= x <= Pi / 4 + // and another one for Pi / 4 < x <= Pi / 2 + // Both branches will be computed. + + poly_mask._vecb0 = (vbool32)vec_cmpeq((imm0 & vi_2), vi_0); + poly_mask._vecb1 = (vbool32)vec_cmpeq((imm1 & vi_2), vi_0); + + // The magic pass: "Extended precision modular arithmetic" + // x = ((x - y * DP1) - y * DP2) - y * DP3; + x = y.madd(minus_cephes_dp1, x); + x = y.madd(minus_cephes_dp2, x); + x = y.madd(minus_cephes_dp3, x); + + // Evaluate the first polynom (0 <= x <= Pi/4) + auto z = x * x; + y = z.madd(coscof_p0, coscof_p1); + y = y.madd(z, coscof_p2); + y = y * z * z; + y = y - z * half + one; + + // Evaluate the second polynom (Pi/4 <= x <= 0) + auto y_2 = z.madd(sincof_p0, sincof_p1); + y_2 = y_2.madd(z, sincof_p2); + y_2 = y_2 * z; + y_2 = y_2.madd(x, x); + + // select the correct result from the two polynoms + y = blendv(y, y_2, poly_mask); + // update the sign + y = y ^ sign_bit; + + return y; + } + Vec256 C10_ALWAYS_INLINE cosh() const { + // cosh = 1/2 * (e^x + e^-x) + auto x = abs(); + auto e_x = x.exp(); + auto ret = (e_x + Vec256(one) / e_x) * half; + // inf and nan checks +#if 0 + ret = blendv(ret, v_inf, x >= vf_89); + ret = blendv(ret, v_inf, ret._isnan()); + ret = blendv(ret, v_nan, this->_isnan()); +#endif + return ret; + } + Vec256 C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << "," << _vec0[2] << "," + << _vec0[3] << ","; + std::cout << _vec1[0] << "," << _vec1[1] << "," << _vec1[2] << "," + << _vec1[3] << std::endl; + } + + Vec256 C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE sin() const { + // take the absolute value and xtract sign + auto x = abs(); + auto sign_bit = (*this) & sign_mask; + + // scale by 4/Pi + auto y = x * _4div_pi; + // store the integer part of y in mm0 + + // j=(j+1) & (~1) (see the cephes sources) + vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; + vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; + y._vec0 = vec_float(imm0); + y._vec1 = vec_float(imm1); + // get the swap sign flag + Vec256 swap_sign_bit, poly_mask; + swap_sign_bit._vecb0 = (vbool32)vec_sl(imm0 & vi_4, vu_29); + swap_sign_bit._vecb1 = (vbool32)vec_sl(imm1 & vi_4, vu_29); + // get the polynom selection mask + // there is one polynom for 0 <= x <= Pi/4 + // and another one for Pi/4 C10_ALWAYS_INLINE sinh() const { + auto temp_abs = abs(); + // get exponent + auto ret = temp_abs.exp(); + auto recp = Vec256(half) / ret; + auto v = ret * half - recp; + // extract the sign bit (upper one) + auto sign_bit = (*this) & sign_mask; + auto z = temp_abs * temp_abs; + auto y = z.madd(p0, p1); + y = y.madd(z, p2); + y = (y * z).madd(temp_abs, temp_abs); + // check and select + auto result = blendv(y, v, temp_abs > one); + return result | sign_bit; + } + Vec256 C10_ALWAYS_INLINE tan() const { + return {Sleef_tanf4_u10vsx(_vec0), Sleef_tanf4_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE tanh() const { + auto x = *this; + auto vabs = abs(); + // get exponent + auto exp2x = (vabs + vabs).exp(); + auto vv = Vec256(one) - Vec256(two) / (exp2x + one); + // extract the sign bit (upper one) + auto sign_bit = (*this) & sign_mask; + auto z = vabs * vabs; + auto y = z.madd(tanh_p0, tanh_p1); + auto tmp = y.madd(z, tanh_p2); + y = z.madd(tmp, tanh_p3); + tmp = y.madd(z, tanh_p4); + y = tmp * z; + tmp = y.madd(x, x); + // add sign + vv = vv | sign_bit; + // check and select + auto sel_mask = vabs >= tanh_0p625; + auto max_mask = vabs > tanh_half_max; + auto max_ret = sign_bit ^ one; + return blendv(blendv(tmp, vv, sel_mask), max_ret, max_mask); + } + Vec256 C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vec256 C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE reciprocal() const { + return Vec256(one) / (*this); + } + Vec256 C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vec256 C10_ALWAYS_INLINE pow(const Vec256& exp) const { + auto x = *this; + auto sign_bit = (*this) & sign_mask; + // |b| + auto exp_abs = exp.abs(); + auto exp_trunc = exp.trunc(); + Vec256 odd_mask; + odd_mask._vecb0 = (vec_signed(exp._vec0) & vi_1) != vi_0; + odd_mask._vecb1 = (vec_signed(exp._vec1) & vi_1) != vi_0; + // using ln fuction + auto temp = (abs().log() * exp).exp(); + + // is odd or even check from Sleef + auto is_int = (exp == exp_trunc) | (exp_abs >= vcheck); + auto is_odd = odd_mask & is_int & (exp_abs < vcheck); + // if even then then pow result should be absolute + auto temp_sign = temp | sign_bit; // copy_sign + auto out = blendv(temp, temp_sign, is_odd); + // x<0 and y != N, then NAN + auto out1 = blendv(out, v_nan, ((exp.floor() != exp) & (x < zero))); + // y = 0 then 1 + return blendv(out1, one, (exp_abs == zero)); + } + + Vec256 fmod(const Vec256& b) const { + return {Sleef_fmodf4_vsx(_vec0, b._vec0),Sleef_fmodf4_vsx(_vec1, b._vec1)}; + } + + Vec256 hypot(const Vec256& b) const { + return {Sleef_hypotf4_u05vsx(_vec0, b._vec0), Sleef_hypotf4_u05vsx(_vec1, b._vec1)}; + } + + Vec256 nextafter(const Vec256& b) const { + return {Sleef_nextafterf4_vsx(_vec0, b._vec0), Sleef_nextafterf4_vsx(_vec1, b._vec1)}; + } + + Vec256 igamma(const Vec256& x) const { + return mapbi(calc_igamma, x); + } + + Vec256 igammac(const Vec256& x) const { + return mapbi(calc_igammac, x); + } + + Vec256 i0() const { + return map(calc_i0); + } + + DEFINE_MEMBER_OP(operator==, float, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, float, vec_cmpne) + DEFINE_MEMBER_OP(operator<, float, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, float, vec_cmple) + DEFINE_MEMBER_OP(operator>, float, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, float, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge) + DEFINE_MEMBER_OP(operator+, float, vec_add) + DEFINE_MEMBER_OP(operator-, float, vec_sub) + DEFINE_MEMBER_OP(operator*, float, vec_mul) + DEFINE_MEMBER_OP(operator/, float, vec_div) + DEFINE_MEMBER_OP(maximum, float, vec_max) + DEFINE_MEMBER_OP(minimum, float, vec_min) + DEFINE_MEMBER_OP(operator&, float, vec_and) + DEFINE_MEMBER_OP(operator|, float, vec_or) + DEFINE_MEMBER_OP(operator^, float, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd) +}; + +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum(const Vec256& a, const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_int16_vsx.h new file mode 100644 index 000000000000..33460abe2a58 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_int16_vsx.h @@ -0,0 +1,351 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vint16 _vec0; + vint16 _vec1; + }; + struct { + vbool16 _vecb0; + vbool16 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int16_t; + using vec_internal_type = vint16; + using vec_internal_mask_type = vbool16; + static constexpr int size() { + return 16; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vint16 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool16 v1, vbool16 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(int16_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + + C10_ALWAYS_INLINE Vec256( + int16_t scalar1, + int16_t scalar2, + int16_t scalar3, + int16_t scalar4, + int16_t scalar5, + int16_t scalar6, + int16_t scalar7, + int16_t scalar8, + int16_t scalar9, + int16_t scalar10, + int16_t scalar11, + int16_t scalar12, + int16_t scalar13, + int16_t scalar14, + int16_t scalar15, + int16_t scalar16) + : _vec0{vint16{ + scalar1, + scalar2, + scalar3, + scalar4, + scalar5, + scalar6, + scalar7, + scalar8}}, + _vec1{vint16{ + scalar9, + scalar10, + scalar11, + scalar12, + scalar13, + scalar14, + scalar15, + scalar16}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t<(mask & 65535) == 65535, Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 255), Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + + return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) && + ((mask & 255) != 255)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return { + (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), + (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + static Vec256 arange(int16_t base = 0, int16_t step = 1) { + return Vec256( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not) + DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int16_t, vec_add) + DEFINE_MEMBER_OP(operator-, int16_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int16_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /) + DEFINE_MEMBER_OP(maximum, int16_t, vec_max) + DEFINE_MEMBER_OP(minimum, int16_t, vec_min) + DEFINE_MEMBER_OP(operator&, int16_t, vec_and) + DEFINE_MEMBER_OP(operator|, int16_t, vec_or) + DEFINE_MEMBER_OP(operator^, int16_t, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_int32_vsx.h new file mode 100644 index 000000000000..2ee2318f0349 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_int32_vsx.h @@ -0,0 +1,281 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int32_t; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + static constexpr int size() { + return 8; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(int32_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + int32_t scalar1, + int32_t scalar2, + int32_t scalar3, + int32_t scalar4, + int32_t scalar5, + int32_t scalar6, + int32_t scalar7, + int32_t scalar8) + : _vec0{vint32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t<(mask & 255) == 255, Vec256> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 15), Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + + return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) && + ((mask & 15) != 15)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return { + (vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), + (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + static Vec256 arange(int32_t base = 0.f, int32_t step = 1.f) { + return Vec256( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not) + DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int32_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int32_t, vec_add) + DEFINE_MEMBER_OP(operator-, int32_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int32_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /) + DEFINE_MEMBER_OP(maximum, int32_t, vec_max) + DEFINE_MEMBER_OP(minimum, int32_t, vec_min) + DEFINE_MEMBER_OP(operator&, int32_t, vec_and) + DEFINE_MEMBER_OP(operator|, int32_t, vec_or) + DEFINE_MEMBER_OP(operator^, int32_t, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_int64_vsx.h new file mode 100644 index 000000000000..d752f71c9a63 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_int64_vsx.h @@ -0,0 +1,233 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vint64 _vec0; + vint64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int64_t; + using vec_internal_type = vint64; + using vec_internal_mask_type = vbool64; + static constexpr int size() { + return 4; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vint64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(int64_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + int64_t scalar1, + int64_t scalar2, + int64_t scalar3, + int64_t scalar4) + : _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask & 15) == 15, Vec256> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t<(mask > 0 && mask < 3), Vec256> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + const vbool64 mask_1st = (vbool64){g0, g1}; + return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1}; + } + + template + static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return {a._vec0, (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 3) && (mask & 3) != 0 && (mask & 15) != 15, + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_1st = (vbool64){g0, g1}; + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return { + (vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), + (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + static Vec256 arange(int64_t base = 0., int64_t step = 1.) { + return Vec256(base, base + step, base + 2 * step, base + 3 * step); + } + + static Vec256 C10_ALWAYS_INLINE + set(const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + static_assert(sizeof(double) == sizeof(value_type)); + const double* dptr = reinterpret_cast(ptr); + return {// treat it as double load + (vint64)vec_vsx_ld(offset0, dptr), + (vint64)vec_vsx_ld(offset16, dptr)}; + } + + __at_align32__ double tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + (vint64)vec_vsx_ld(offset0, tmp_values), + (vint64)vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + double* dptr = reinterpret_cast(ptr); + vec_vsx_st((vfloat64)_vec0, offset0, dptr); + vec_vsx_st((vfloat64)_vec1, offset16, dptr); + } else if (count > 0) { + __at_align32__ double tmp_values[size()]; + vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); + vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not) + DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int64_t, vec_add) + DEFINE_MEMBER_OP(operator-, int64_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int64_t, vec_mul) + DEFINE_MEMBER_OP(operator/, int64_t, vec_div) + DEFINE_MEMBER_OP(maximum, int64_t, vec_max) + DEFINE_MEMBER_OP(minimum, int64_t, vec_min) + DEFINE_MEMBER_OP(operator&, int64_t, vec_and) + DEFINE_MEMBER_OP(operator|, int64_t, vec_or) + DEFINE_MEMBER_OP(operator^, int64_t, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_qint32_vsx.h new file mode 100644 index 000000000000..a47e295ce03b --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_qint32_vsx.h @@ -0,0 +1,242 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vec256<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vec256, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vec256 -> 1x Vec256 +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vec256::float_num_vecs +// iterations. + +namespace at { +namespace vec256 { +namespace { + +template <> +struct Vec256 { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vec256() {} + + static constexpr int size() { + return 8; + } + + static constexpr size_t float_num_vecs() { + return 1; + } + static constexpr int int_num_vecs() { + return 1; + } + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + C10_ALWAYS_INLINE Vec256(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + + Vec256(const c10::qint32& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + float_vec_return_type dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 scale_zp_premul0 = scale_zp_premul.vec0(); + vfloat32 scale_zp_premul1 = scale_zp_premul.vec1(); + return {Vec256{ + vec_madd(scale_vec0, float_vals0, scale_zp_premul0), + vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}}; + } + + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vec256 retval; + + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)(zero_point)); + Vec256 vf0 = rhs[0]; + + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + Vec256 relu(Vec256 zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vec256 relu6( + Vec256 zero_point, + Vec256 q_six) const { + vint32 max0 = vec_max(_vec0, zero_point._vec0); + vint32 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vec256 b) const { + return {*this - b}; + } + + static Vec256 requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 vec_mult = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + Vec256 vi = inp[0]; + vfloat32 vecf0 = vec_float(vi.vec0()); + vfloat32 vecf1 = vec_float(vi.vec1()); + + vecf0 = vec_mul(vecf0, vec_mult); + vecf1 = vec_mul(vecf1, vec_mult); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + + vint32 veci0 = vec_add(vec_signed(vecf0),vec_zero_point); + vint32 veci1 = vec_add(vec_signed(vecf1),vec_zero_point); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + void dump() const { + std::cout << _vec0[0] << " "; + std::cout << _vec0[1] << " "; + std::cout << _vec0[2] << " "; + std::cout << _vec0[3] << " "; + std::cout << _vec1[0] << " "; + std::cout << _vec1[1] << " "; + std::cout << _vec1[2] << " "; + std::cout << _vec1[3] << " "; + std::cout << std::endl; + } + + DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /) + DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_qint8_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_qint8_vsx.h new file mode 100644 index 000000000000..f8b6eced60ef --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_qint8_vsx.h @@ -0,0 +1,404 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vec256<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vec256, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vec256 -> 4x Vec256 +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vec256::float_num_vecs +// iterations. + +namespace at { +namespace vec256 { +namespace { + +template <> +struct Vec256 { + private: + union { + struct { + vint8 _vec0; + vint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vec256() {} + static constexpr int size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + using vec_internal_type = vint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vec256(const c10::qint8& val) + : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {} + + C10_ALWAYS_INLINE Vec256(const Vec256& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vec256(vint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vec256 loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 scale_zp_premul0 = scale_zp_premul.vec0(); + vfloat32 scale_zp_premul1 = scale_zp_premul.vec1(); + return { + Vec256{ + vec_madd(scale_vec0, vecf0_0, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_1, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_2, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_3, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}}; + } + + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vint32 vmin = vec_splats(min_val); + // vint32 vmax = vec_splats(max_val); + + Vec256 vf0 = rhs[0]; + Vec256 vf1 = rhs[1]; + Vec256 vf2 = rhs[2]; + Vec256 vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf2 = vec_mul(vecf2, inverse_scale_v); + vecf3 = vec_mul(vecf3, inverse_scale_v); + + vecf4 = vec_mul(vecf4, inverse_scale_v); + vecf5 = vec_mul(vecf5, inverse_scale_v); + vecf6 = vec_mul(vecf6, inverse_scale_v); + vecf7 = vec_mul(vecf7, inverse_scale_v); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + // veci0 = vec_min(vmax, vec_max( vmin, vecf0)) ; + // veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ; + // veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ; + // veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ; + + // veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ; + // veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ; + // veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ; + // veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ; + // vec_packs CLAMP already + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vec256 C10_ALWAYS_INLINE relu(Vec256 zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vec256 C10_ALWAYS_INLINE + relu6(Vec256 zero_point, Vec256 q_six) const { + vint8 max0 = vec_max(_vec0, zero_point._vec0); + vint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vec256 b) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecBshi0 = vec_unpackh(b._vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + vint16 vecBshi1 = vec_unpackl(b._vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecBshi2 = vec_unpackh(b._vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + vint16 vecBshi3 = vec_unpackl(b._vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vec256(veci0 - vecBi0, veci1 - vecBi1), + Vec256(veci2 - vecBi2, veci3 - vecBi3), + Vec256(veci4 - vecBi4, veci5 - vecBi5), + Vec256(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vec256 requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vec256 vi0 = inp[0]; + Vec256 vi1 = inp[1]; + Vec256 vi2 = inp[2]; + Vec256 vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + void dump() const { + value_type vals[size()]; + store((void*)vals); + for (int i = 0; i < size(); ++i) { + std::cout << (int)(vals[i]) << " "; + } + std::cout << std::endl; + } + + DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /) + DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_quint8_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_quint8_vsx.h new file mode 100644 index 000000000000..96809ce32593 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_quint8_vsx.h @@ -0,0 +1,413 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vec256<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vec256, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vec256 -> 4x Vec256 +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vec256::float_num_vecs +// iterations. + +namespace at { +namespace vec256 { +namespace { + +const vint16 mask_unsigned = vec_splats((short int)0xFF); +template <> +struct Vec256 { + private: + union { + struct { + vuint8 _vec0; + vuint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vec256() {} + static constexpr int size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + using vec_internal_type = vuint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vec256(const c10::quint8& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + C10_ALWAYS_INLINE Vec256(const Vec256& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vec256(vuint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vec256 loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 scale_zp_premul0 = scale_zp_premul.vec0(); + vfloat32 scale_zp_premul1 = scale_zp_premul.vec1(); + return { + Vec256{ + vec_madd(scale_vec0, vecf0_0, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_1, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_2, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_3, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}}; + } + + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 vec_inverse = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vuint32 vmin = vec_splats(min_val); + // vuint32 vmax = vec_splats(max_val); + Vec256 vf0 = rhs[0]; + Vec256 vf1 = rhs[1]; + Vec256 vf2 = rhs[2]; + Vec256 vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, vec_inverse); + vecf1 = vec_mul(vecf1, vec_inverse); + vecf2 = vec_mul(vecf2, vec_inverse); + vecf3 = vec_mul(vecf3, vec_inverse); + + vecf4 = vec_mul(vecf4, vec_inverse); + vecf5 = vec_mul(vecf5, vec_inverse); + vecf6 = vec_mul(vecf6, vec_inverse); + vecf7 = vec_mul(vecf7, vec_inverse); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vec256 C10_ALWAYS_INLINE relu(Vec256 zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vec256 C10_ALWAYS_INLINE + relu6(Vec256 zero_point, Vec256 q_six) const { + vuint8 max0 = vec_max(_vec0, zero_point._vec0); + vuint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vec256 b) const { + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecBshi0 = vec_unpackh((vint8)b._vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + vint16 vecBshi1 = vec_unpackl((vint8)b._vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecBshi2 = vec_unpackh((vint8)b._vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + vint16 vecBshi3 = vec_unpackl((vint8)b._vec1); + + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecBshi0 = vec_and(vecBshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + vecBshi1 = vec_and(vecBshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecBshi2 = vec_and(vecBshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + vecBshi3 = vec_and(vecBshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vec256(veci0 - vecBi0, veci1 - vecBi1), + Vec256(veci2 - vecBi2, veci3 - vecBi3), + Vec256(veci4 - vecBi4, veci5 - vecBi5), + Vec256(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vec256 requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vec256 vi0 = inp[0]; + Vec256 vi1 = inp[1]; + Vec256 vi2 = inp[2]; + Vec256 vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + void dump() const { + value_type vals[size()]; + store((void*)vals); + for (int i = 0; i < size(); ++i) { + std::cout << (int)(vals[i]) << " "; + } + std::cout << std::endl; + } + + DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /) + DEFINE_MEMBER_OP(maximum, c10::quint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vsx_helpers.h b/aten/src/ATen/cpu/vec256/vsx/vsx_helpers.h new file mode 100644 index 000000000000..40cb7ef7a66e --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vsx_helpers.h @@ -0,0 +1,332 @@ +#pragma once +#include +#include +#include + +using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; +using vbool16 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; +using vbool32 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) int; +using vbool64 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) long long; +using vint8 = __attribute__((altivec(vector__))) signed char; +using vint16 = __attribute__((altivec(vector__))) signed short; +using vint32 = __attribute__((altivec(vector__))) signed int; +using vint64 = __attribute__((altivec(vector__))) signed long long; +using vuint8 = __attribute__((altivec(vector__))) unsigned char; +using vuint16 = __attribute__((altivec(vector__))) unsigned short; +using vuint32 = __attribute__((altivec(vector__))) unsigned int; +using vuint64 = __attribute__((altivec(vector__))) unsigned long long; +using vfloat32 = __attribute__((altivec(vector__))) float; +using vfloat64 = __attribute__((altivec(vector__))) double; + +#if !defined(vec_float) +C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) { + vfloat32 vec_out; + __asm__("xvcvsxwsp %x0,%x1" : "=wf"(vec_out) : "wa"(vec_in)); + return vec_out; +} +#endif + +#define vec_not(a) vec_nor(a, a) + +#define DEFINE_MEMBER_UNARY_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op() const { \ + return Vec256{func(_vec0), func(_vec1)}; \ + } + +#define DEFINE_MEMBER_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& other) const { \ + return Vec256{ \ + func(_vec0, other._vec0), func(_vec1, other._vec1)}; \ + } + +#define DEFINE_MEMBER_BITWISE_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& other) const { \ + return Vec256{ \ + func(_vecb0, other._vecb0), func(_vecb1, other._vecb1)}; \ + } + +#define DEFINE_MEMBER_TERNARY_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op( \ + const Vec256& b, const Vec256& c) const { \ + return Vec256{ \ + func(_vec0, b._vec0, c._vec0), func(_vec1, b._vec1, c._vec1)}; \ + } + +#define DEFINE_MEMBER_EMULATE_BINARY_OP(op, op_type, binary_op) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& b) const { \ + Vec256::vec_internal_type ret_0; \ + Vec256::vec_internal_type ret_1; \ + for (int i = 0; i < Vec256::size() / 2; i++) { \ + ret_0[i] = _vec0[i] binary_op b._vec0[i]; \ + ret_1[i] = _vec1[i] binary_op b._vec1[i]; \ + } \ + return Vec256{ret_0, ret_1}; \ + } + + +#define DEFINE_MEMBER_OP_AND_ONE(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& other) const { \ + using vvtype = Vec256::vec_internal_type; \ + const vvtype v_one = vec_splats(static_cast(1.0)); \ + vvtype ret0 = (vvtype)func(_vec0, other._vec0); \ + vvtype ret1 = (vvtype)func(_vec1, other._vec1); \ + return Vec256{vec_and(ret0, v_one), vec_and(ret1, v_one)}; \ + } + +#define DEFINE_CLAMP_FUNCS(operand_type) \ + template <> \ + Vec256 C10_ALWAYS_INLINE clamp( \ + const Vec256& a, \ + const Vec256& min, \ + const Vec256& max) { \ + return Vec256{ \ + vec_min(max.vec0(), vec_max(a.vec0(), min.vec0())), \ + vec_min(max.vec1(), vec_max(a.vec1(), min.vec1()))}; \ + } \ + template <> \ + Vec256 C10_ALWAYS_INLINE clamp_min( \ + const Vec256& a, const Vec256& min) { \ + return Vec256{ \ + vec_max(a.vec0(), min.vec0()), vec_max(a.vec1(), min.vec1())}; \ + } \ + template <> \ + Vec256 C10_ALWAYS_INLINE clamp_max( \ + const Vec256& a, const Vec256& max) { \ + return Vec256{ \ + vec_min(a.vec0(), max.vec0()), vec_min(a.vec1(), max.vec1())}; \ + } + +#define DEFINE_REINTERPRET_CAST_FUNCS( \ + first_type, cast_type, cast_inner_vector_type) \ + template <> \ + C10_ALWAYS_INLINE Vec256 cast( \ + const Vec256& src) { \ + return Vec256{(cast_inner_vector_type)src.vec0(), \ + (cast_inner_vector_type)src.vec1()}; \ + } + +#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(first_type) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, double, vfloat64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, float, vfloat32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int64_t, vint64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int32_t, vint32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int16_t, vint16) + +// it can be used to emulate blend faster +constexpr int blendChoice(uint32_t mask, uint32_t half1 = 0xF, uint32_t half2 = 0xF0) { + uint32_t none = 0; + uint32_t both = half1 | half2; + // clamp it between 0 and both + mask = mask & both; + // return (a._vec0, a._vec1) + if (mask == none) return 0; + // return (b._vec0,b._vec1) + else if (mask == both) + return 1; + // return (b._vec0,a._vec1) + else if (mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (mask > 0 && mask < half1) + return 4; + // return (*_vec0,b._vec1) + else if ((mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((mask & half1) == 0 && mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((mask & half1) == half1 && mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +constexpr int blendChoiceDbl(uint32_t mask) { + // clamp it 0 and 0xF + return blendChoice(mask, 0x3, 0xC); +} + +constexpr vbool32 VsxMask1(uint32_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (vbool32){g0, g1, g2, g3}; +} + +constexpr vbool32 VsxMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xFF) >> 4; + return VsxMask1(mask2); +} + +constexpr vbool64 VsxDblMask1(uint32_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (vbool64){g0, g1}; +} + +constexpr vbool64 VsxDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +constexpr int maskForComplex(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) complex_mask |= 3; + if (mask & 2) complex_mask |= (3 << 2); + if (mask & 4) complex_mask |= (3 << 4); + if (mask & 8) complex_mask |= (3 << 6); + return complex_mask; +} + +constexpr int maskForComplexDbl(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) complex_mask |= 3; + if (mask & 2) complex_mask |= (3 << 2); + return complex_mask; +} + +constexpr int blendChoiceComplex(uint32_t mask) { + return blendChoice(maskForComplex(mask)); +} + +constexpr int blendChoiceComplexDbl(uint32_t mask) { + return blendChoiceDbl(maskForComplexDbl(mask)); +} + +constexpr vbool32 VsxComplexMask1(uint32_t mask) { + return VsxMask1(maskForComplex(mask)); +} + +constexpr vbool32 VsxComplexMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxMask1(maskForComplex(mask2)); +} + +constexpr vbool64 VsxComplexDblMask1(uint32_t mask) { return VsxDblMask1(mask); } + +constexpr vbool64 VsxComplexDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +// constants +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { +// + constexpr int offset0 = 0; + constexpr int offset16 = 16; + +//#Constants +const vuint8 mask_zero_bits = vuint8{128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 96, 64, 32, 0}; + +const vuint8 swap_mask = + vuint8{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; + +const vint32 v0x7f = vec_splats(0x7f); +const vint32 vi_0 = vec_splats((int)(0)); +const vint32 vi_1 = vec_splats((int)1); +const vint32 vi_2 = vec_splats((int)2); +const vint32 vi_4 = vec_splats((int)4); +const vint32 vi_inv1 = vec_splats((int)~1); +const vuint32 vu_29 = vec_splats(29u); +const vuint32 vu_23 = vec_splats(23u); + +const vbool32 inv_mant_mask = (vbool32)vec_splats((unsigned int)~0xff800000); +const vbool32 sign_mask = (vbool32)vec_splats((int)0x80000000); +const vbool32 real_mask = vbool32{0xFFFFFFFF, 0x0, 0xFFFFFFFF, 0x0}; +const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF}; +const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000}; +const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0}; + +const vbool64 vd_imag_mask = vbool64{0x0, 0xFFFFFFFFFFFFFFFF}; +const vbool64 vd_real_mask = vbool64{0xFFFFFFFFFFFFFFFF, 0x0}; +const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000}; +const vbool64 vd_rsign_mask = vbool64{0x8000000000000000, 0x0}; + +const vfloat32 zero = vec_splats(0.f); +const vfloat32 half = vec_splats(0.5f); +const vfloat32 one = vec_splats(1.f); +const vfloat32 two = vec_splats(2.0f); +const vfloat32 _4div_pi = vec_splats(1.27323954473516f); +const vfloat32 v_inf = (vfloat32)vec_splats(0x7f800000u); +const vfloat32 v_minus_inf = vfloat32{ 0xff800000u, 0xff800000u, 0xff800000u, 0xff800000u }; +const vfloat32 v_nan = (vfloat32)vec_splats(0x7fffffff); +const vfloat32 log10e_inv = vec_splats(0.43429448190325176f); +const vfloat32 log2e_inv = vec_splats(1.4426950408889634f); +const vfloat32 log2eB_inv = vec_splats(1.442695036924675f); +const vfloat32 cephes_SQRTHF = vec_splats(0.707106781186547524f); +const vfloat32 coscof_p0 = vec_splats(2.443315711809948E-005f); +const vfloat32 coscof_p1 = vec_splats(-1.388731625493765E-003f); +const vfloat32 coscof_p2 = vec_splats(4.166664568298827E-002f); +const vfloat32 exp_hi = vec_splats(104.f); +const vfloat32 exp_lo = vec_splats(-104.f); +const vfloat32 exp_p0 = vec_splats(0.000198527617612853646278381f); +const vfloat32 exp_p1 = vec_splats((0.00139304355252534151077271f)); +const vfloat32 exp_p2 = vec_splats(0.00833336077630519866943359f); +const vfloat32 exp_p3 = vec_splats(0.0416664853692054748535156f); +const vfloat32 exp_p4 = vec_splats(0.166666671633720397949219f); +const vfloat32 exp_p5 = vec_splats(0.5f); +const vfloat32 log_p0 = vec_splats(7.0376836292E-2f); +const vfloat32 log_p1 = vec_splats(-1.1514610310E-1f); +const vfloat32 log_p2 = vec_splats(1.1676998740E-1f); +const vfloat32 log_p3 = vec_splats(-1.2420140846E-1f); +const vfloat32 log_p4 = vec_splats(+1.4249322787E-1f); +const vfloat32 log_p5 = vec_splats(-1.6668057665E-1f); +const vfloat32 log_p6 = vec_splats(+2.0000714765E-1f); +const vfloat32 log_p7 = vec_splats(-2.4999993993E-1f); +const vfloat32 log_p8 = vec_splats(+3.3333331174E-1f); +const vfloat32 log_q1 = vec_splats(-2.12194440e-4f); +const vfloat32 log_q2 = vec_splats(0.693359375f); +const vfloat32 max_logf = vec_splats(88.02969187150841f); +const vfloat32 max_numf = vec_splats(1.7014117331926442990585209174225846272e38f); +const vfloat32 min_inf = (vfloat32)vec_splats(0xff800000u); +const vfloat32 min_norm_pos = (vfloat32)vec_splats(0x0800000u); +const vfloat32 minus_cephes_dp1 = vec_splats(-0.78515625f); +const vfloat32 minus_cephes_dp2 = vec_splats(-2.4187564849853515625e-4f); +const vfloat32 minus_cephes_dp3 = vec_splats(-3.77489497744594108e-8f); +const vfloat32 negln2f_hi = vec_splats(-0.693145751953125f); +const vfloat32 negln2f_lo = vec_splats(-1.428606765330187045e-06f); +const vfloat32 p0 = vec_splats(2.03721912945E-4f); +const vfloat32 p1 = vec_splats(8.33028376239E-3f); +const vfloat32 p2 = vec_splats(1.66667160211E-1f); +const vfloat32 sincof_p0 = vec_splats(-1.9515295891E-4f); +const vfloat32 sincof_p1 = vec_splats(8.3321608736E-3f); +const vfloat32 sincof_p2 = vec_splats(-1.6666654611E-1f); +const vfloat32 tanh_0p625 = vec_splats(0.625f); +const vfloat32 tanh_half_max = vec_splats(44.014845935754205f); +const vfloat32 tanh_p0 = vec_splats(-5.70498872745E-3f); +const vfloat32 tanh_p1 = vec_splats(2.06390887954E-2f); +const vfloat32 tanh_p2 = vec_splats(-5.37397155531E-2f); +const vfloat32 tanh_p3 = vec_splats(1.33314422036E-1f); +const vfloat32 tanh_p4 = vec_splats(-3.33332819422E-1f); +const vfloat32 vcheck = vec_splats((float)(1LL << 24)); +const vfloat32 imag_one = vfloat32{0.f, 1.f, 0.f, 1.f}; +const vfloat32 imag_half = vfloat32{0.f, 0.5f, 0.f, 0.5f}; +const vfloat32 sqrt2_2 = vfloat32{0.70710676908493042f, 0.70710676908493042, + 0.70710676908493042, 0.70710676908493042}; +const vfloat32 pi_2 = vfloat32{M_PI / 2, 0.0, M_PI / 2, 0.0}; +const vfloat32 vf_89 = vfloat32{89.f, 89.f, 89.f, 89.f}; +const vfloat64 vd_one = vec_splats(1.0); +const vfloat64 vd_zero = vec_splats(0.0); +const vfloat64 vd_log10e_inv = vec_splats(0.43429448190325176); +const vfloat64 vd_log2e_inv = vec_splats(1.4426950408889634); +const vfloat64 vd_imag_one = vfloat64{0.0, 1.0}; +const vfloat64 vd_imag_half = vfloat64{0.0, 0.5}; +const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757}; +const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0}; + +} // namespace +} // namespace vec256 +} // namespace at + diff --git a/aten/src/ATen/cuda/CUDAFuture.h b/aten/src/ATen/cuda/CUDAFuture.h new file mode 100644 index 000000000000..ae43fb2a2dd6 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAFuture.h @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { namespace cuda { + +struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future { + public: + using at::ivalue::Future::Future; + + void setDataPtrExtractor(DataPtrExtractor dataPtrExtractor) override { + std::unique_lock lock(dataPtrExtractorMutex_); + dataPtrExtractor_ = std::move(dataPtrExtractor); + } + + protected: + c10::intrusive_ptr createInstance(at::TypePtr type) override { + auto fut = c10::make_intrusive(std::move(type)); + // The new future needs the DataPtr extractor when it gets marked complete + // but this might happen immediately inline or in parallel by another + // thread. In both these cases this would/might happen before the user has + // time to set their own DataPtr extractor, which might lead to failures + // if the default extractor can't handle some of the user's types. + // Therefore we propagate our extractor. + fut->setDataPtrExtractor(dataPtrExtractor_); + return fut; + } + + void postMarkCompletedHook(const at::IValue& value) override { + currentDevice_ = c10::cuda::current_device(); + + // Extract them once and cache them for later uses. + dataPtrs_ = extractDataPtrs(value); + + std::vector isCudaDeviceUsed(c10::cuda::device_count(), false); + for (const at::DataPtr& data_ptr : dataPtrs_) { + if (data_ptr.device().is_cuda()) { + isCudaDeviceUsed[data_ptr.device().index()] = true; + } + } + + for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) { + if (isCudaDeviceUsed[idx]) { + at::cuda::CUDAEvent cudaEvent; + cudaEvent.record(at::cuda::getCurrentCUDAStream(idx)); + cudaEvents_.push_back(std::move(cudaEvent)); + } + } + } + + std::function wrapCallback( + std::function callback) override { + return [this, callback{std::move(callback)}]() { + // We'd love to get a stream for all devices, even those that are not used + // by the value, because the callback could use those other devices, but + // unfortunately this could cause a deadlock with NCCL. See + // https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414 + // In general, if some devices haven't been used yet, by getting a stream + // for them we'd initialize them, and in addition to causing NCCL to + // misbehaving this also ends up using memory on those devices, which the + // user might not want. + std::vector streams; + for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) { + c10::DeviceIndex idx = cudaEvent.device_index(); + // FIXME Should we find a way to allow to change the priority of + // streams? + at::cuda::CUDAStream stream = + at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx); + cudaEvent.block(stream); + streams.push_back(stream); + } + + // Use the dedicated callback stream to run callback. + at::cuda::CUDAMultiStreamGuard streamGuard(streams); + + // Do not free the underlying data storage of value_ before its + // usage on the stream finishes. + for (const at::DataPtr& data_ptr : dataPtrs_) { + if (data_ptr.device().is_cuda()) { + c10::cuda::CUDACachingAllocator::recordStream( + data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); + } + } + + c10::cuda::CUDAGuard deviceGuard(currentDevice_); + + callback(); + }; + } + + void postWaitHook(const at::IValue& value) override { + for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) { + cudaEvent.block( + at::cuda::getCurrentCUDAStream(cudaEvent.device_index())); + } + + for (const at::DataPtr& data_ptr : dataPtrs_) { + if (data_ptr.device().is_cuda()) { + c10::cuda::CUDACachingAllocator::recordStream( + data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); + } + } + } + + private: + // The device that was current when markCompleted was called, which we'll + // restore when invoking callbacks. + c10::DeviceIndex currentDevice_; + + // The events that correspond to the completion of the async I/O kernels. They + // are recorded on the appropriate streams when the future is marked completed + // and can then be queried/waited/blocked on. There is one event for each + // distinct device on which the value's tensors reside. + std::vector cudaEvents_; + + // A cached version of the data ptrs extracted from the value when the future + // is first marked completed. + std::vector> dataPtrs_; + + DataPtrExtractor dataPtrExtractor_; + std::mutex dataPtrExtractorMutex_; + + std::vector> extractDataPtrs( + const at::IValue& value) { + std::unique_lock lock(dataPtrExtractorMutex_); + std::vector> data_ptrs; + if (dataPtrExtractor_ != nullptr) { + // If a Python communication hook is used, dataPtrExtractor_ will be + // set in torch/csrc/jit/python/pybind_utils.h, which allows Python + // dependency to be imported. + data_ptrs = dataPtrExtractor_(value); + } else { + // If a C++ communication hook is used, use the default extractor. + data_ptrs = at::ivalue::Future::defaultDataPtrExtractor(value); + } + return data_ptrs; + } +}; + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index f0db9014163a..8a5e4f48e0c0 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -84,7 +84,7 @@ Generator createCUDAGenerator(DeviceIndex device_index) { */ CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index) : c10::GeneratorImpl{Device(DeviceType::CUDA, device_index), - DispatchKeySet(c10::DispatchKey::CUDA)} { + DispatchKeySet(c10::DispatchKey::CUDA)} { at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl"); } @@ -101,20 +101,18 @@ void CUDAGeneratorImpl::set_current_seed(uint64_t seed) { } #define CAPTURE_DEFAULT_GENS_MSG \ -"Non-default (user-constructed) CUDA RNG generators cannot be used " \ -"in regions captured by CUDA graphs. " \ -"If you need a non-default CUDA generator in a captured region, " \ -"please file an issue." +"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \ +"generator on the device that's current when capture begins. " \ +"If you need a non-default (user-supplied) generator, or a generator on another " \ +"device, please file an issue." /** * Gets the current seed of CUDAGeneratorImpl. */ uint64_t CUDAGeneratorImpl::current_seed() const { - TORCH_CHECK((at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) || - ((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index())), - CAPTURE_DEFAULT_GENS_MSG); + // Debatable if current_seed() should be allowed in captured regions. + // Conservatively disallow it for now. + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed"); return seed_; } @@ -151,25 +149,21 @@ uint64_t CUDAGeneratorImpl::philox_offset_per_thread() { } /** - * Prepares this instance for a cuda graph capture region. + * Called by CUDAGraph to prepare this instance for a graph capture region. * offset_extragraph is the initial offset at the start of the graphed region. * offset_intragraph tracks the offset in the graphed region. */ -void CUDAGeneratorImpl::graph_prologue(int64_t* offset_extragraph) { - TORCH_CHECK((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()), - CAPTURE_DEFAULT_GENS_MSG); +void CUDAGeneratorImpl::capture_prologue(int64_t* offset_extragraph) { offset_extragraph_ = offset_extragraph; offset_intragraph_ = 0; + graph_expects_this_gen_ = true; } /** - * Finalizes a cuda graph capture region for this instance. + * Called by CUDAGraph to finalize a graph capture region for this instance. */ -uint64_t CUDAGeneratorImpl::graph_epilogue() { - TORCH_CHECK((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()), - CAPTURE_DEFAULT_GENS_MSG); +uint64_t CUDAGeneratorImpl::capture_epilogue() { + graph_expects_this_gen_ = false; return offset_intragraph_; } @@ -187,7 +181,7 @@ uint64_t CUDAGeneratorImpl::graph_epilogue() { * it intends to generate. * * Increment should be at least the number of curand() random numbers used in - * each thread. It is the user's responsibility to make sure that the increment + * each thread. It is the user's responsibility to make sure the increment * for philox is never smaller than the number of curand() calls. Increment * value > the number of curand() calls won't harm but anything less would mean * that you would be reusing random values from previous calls. @@ -196,17 +190,20 @@ uint64_t CUDAGeneratorImpl::graph_epilogue() { */ PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) { if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - TORCH_CHECK((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()), + TORCH_CHECK(graph_expects_this_gen_, + "philox_cuda_state for an unexpected CUDA generator used during capture. " CAPTURE_DEFAULT_GENS_MSG); uint32_t offset = this->offset_intragraph_; TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <= - std::numeric_limits::max() - increment); + std::numeric_limits::max() - increment); this->offset_intragraph_ += increment; return PhiloxCudaState(this->seed_, this->offset_extragraph_, offset); } else { + TORCH_CHECK(!graph_expects_this_gen_, + "CUDA generator expects graph capture to be underway, " + "but the current stream is not capturing."); uint64_t offset = this->philox_offset_per_thread_; this->philox_offset_per_thread_ += increment; return PhiloxCudaState(this->seed_, offset); diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp new file mode 100644 index 000000000000..74cc5ca09793 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -0,0 +1,168 @@ +#include +#include +#include +#include + +namespace at { +namespace cuda { + +/** + * Note [CUDA Graph Wrapper Class] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Q: Why do we need graph capture and launch bindings in Pytorch? + * Why can't they live in a user extension, for example? + * + * A1: Convenience. + * A2: To ensure valid numerics on replay, some native CUDA ops (like RNG ops with + * CPU statefulness) need cooperation from the capture and replay bindings + * (see Note [CUDA Graph-safe RNG states] in CUDAGeneratorImpl.h). + * + * We can't expect users to know about this cooperation. If users write capture + * bindings naively in an extension, they likely won't interact with the native + * ops properly. Their graphs would yield invalid numerics on replay. + */ + +CUDAGraph::CUDAGraph() + // CUDAStreams may not be default-constructed. + : capture_stream_(at::cuda::getCurrentCUDAStream()) { +#if CUDA_VERSION < 11000 + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::capture_begin() { +#if CUDA_VERSION >= 11000 + TORCH_CHECK(!has_graph_exec_, + "This CUDAGraph instance already owns a captured graph. " + "To capture a new graph, create a new instance."); + + // For now, a CUDAGraph instance only accommodates the default generator on the device that's + // current when capture begins. If any op in the captured region uses a non-default generator, + // or a generator on another device, the offending generator will throw an error. + // These restrictions simplify CUDAGraph, but could be relaxed in the future: + // in principle, the underlying Cuda calls do permit cross-device ops to be captured. + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + + auto options = TensorOptions().device(at::kCUDA).dtype(at::kLong); + offset_extragraph_ = at::empty({1}, options); + + gen->capture_prologue(offset_extragraph_.data_ptr()); + + auto stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK(stream != at::cuda::getDefaultCUDAStream(), + "CUDA graphs must be captured on a non-default stream. " + "(However, after capture, it's ok to replay them on the " + "default stream.)"); + + capture_stream_ = stream; + capture_gen_ = gen; + + // cudaStreamCaptureModeGlobal is the most conservative option to + // prevent potentially unsafe CUDA API calls during capture. See + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); + + // Stashes the current graph's uuid. + cudaStreamCaptureStatus status; + AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id_)); + TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive); +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::capture_end() { +#if CUDA_VERSION >= 11000 + auto stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK(stream == capture_stream_, + "Capture must end on the same stream it began on."); + + AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_)); + TORCH_CHECK(graph_ != NULL, "Invalid capture."); + has_graph_ = true; + + // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, + // who prefer not to report error message through these arguments moving forward + // (they prefer return value, or errors on api calls internal to the capture) + AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); + has_graph_exec_ = true; + + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + TORCH_CHECK(gen == capture_gen_, + "Default CUDA RNG generator on current device at capture end " + "is different from default generator on current device " + "when capture began"); + wholegraph_increment_ = gen->capture_epilogue(); + + // Now that we've instantiated graph_ into graph_exec_, + // we don't need graph_ anymore. + AT_CUDA_CHECK(cudaGraphDestroy(graph_)); + has_graph_ = false; +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::replay() { +#if CUDA_VERSION >= 11000 + TORCH_CHECK(has_graph_exec_, + "Called CUDAGraph::replay without a preceding successful capture."); + + { + c10::OptionalDeviceGuard device_guard{capture_stream_.device()}; + + // Just like any RNG consumer kernel! + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + PhiloxCudaState rng_engine_inputs; + { + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(wholegraph_increment_); + } + offset_extragraph_.fill_(int64_t(rng_engine_inputs.offset_.val)); + + // graph_exec_ may be replayed in any stream. + AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream())); + } +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::reset() { +#if CUDA_VERSION >= 11000 + // I'd prefer these checks throw exceptions, not print warnings, + // but the destructor calls reset(), and at least one CI build + // refuses to compile with a throwing destructor. + // + // Instead of calling reset() in the destructor to clean up, I could + // call reset() in the __del__ method of a thin Python wrapper, + // in which case reset would be allowed to throw exceptions. + // But Stackoverflow does not like user-defined __del__. + // __del__ prevents Graph instances from EVER being garbage collected + // if they participate in a reference cycle. + // And exceptions thrown in __del__ only print a warning anyway. + // + // Calling reset() in the C++ destructor, with warnings instead of exceptions + // if calls fail, is the compromise we chose. + if (has_graph_) { + C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_)); + } + if (has_graph_exec_) { + C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_)); + } +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +CUDAGraph::~CUDAGraph() { + reset(); +} + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h new file mode 100644 index 000000000000..387271715055 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -0,0 +1,43 @@ +#include +#include +#include +#include + +namespace at { +namespace cuda { + +struct TORCH_CUDA_API CUDAGraph { + CUDAGraph(); + ~CUDAGraph(); + + void capture_begin(); + void capture_end(); + void replay(); + void reset(); + + protected: +#if CUDA_VERSION >= 11000 + cudaGraph_t graph_ = NULL; + cudaGraphExec_t graph_exec_ = NULL; +#endif + + // internal states for error checking + bool has_graph_ = false; + bool has_graph_exec_ = false; + + // uuid, retrieved from Cuda + unsigned long long id_; + + // Stream on which capture began + at::cuda::CUDAStream capture_stream_; + + // Default generator on device where capture began + at::CUDAGeneratorImpl* capture_gen_; + + // RNG state trackers + at::Tensor offset_extragraph_; + uint64_t wholegraph_increment_; +}; + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDASolver.cpp b/aten/src/ATen/cuda/CUDASolver.cpp index 00329acda4a9..bcd630a06b9e 100644 --- a/aten/src/ATen/cuda/CUDASolver.cpp +++ b/aten/src/ATen/cuda/CUDASolver.cpp @@ -46,14 +46,14 @@ void getrf>( TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize( handle, m, n, reinterpret_cast(dA), ldda, &lwork)); auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - void* buffer = allocator.allocate(sizeof(cuDoubleComplex) * lwork).get(); + auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex) * lwork); TORCH_CUSOLVER_CHECK(cusolverDnZgetrf( handle, m, n, reinterpret_cast(dA), ldda, - static_cast(buffer), + static_cast(dataPtr.get()), ipiv, info)); } @@ -71,14 +71,14 @@ void getrf>( TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize( handle, m, n, reinterpret_cast(dA), ldda, &lwork)); auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - void* buffer = allocator.allocate(sizeof(cuComplex) * lwork).get(); + auto dataPtr = allocator.allocate(sizeof(cuComplex) * lwork); TORCH_CUSOLVER_CHECK(cusolverDnCgetrf( handle, m, n, reinterpret_cast(dA), ldda, - static_cast(buffer), + static_cast(dataPtr.get()), ipiv, info)); } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 28b9738034e7..00424ab83ba0 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -163,7 +163,9 @@ bool CUDAHooks::hasPrimaryContext(int64_t device_index) const { TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(), "hasPrimaryContext expects a valid device index, but got device_index=", device_index); unsigned int ctx_flags; - int ctx_is_active; + // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird + // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero. + int ctx_is_active = 0; AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active)); return ctx_is_active == 1; } diff --git a/aten/src/ATen/miopen/Handle.cpp b/aten/src/ATen/miopen/Handle.cpp index 8965ef5a2cce..6b8c7c6421c4 100644 --- a/aten/src/ATen/miopen/Handle.cpp +++ b/aten/src/ATen/miopen/Handle.cpp @@ -1,39 +1,53 @@ -#include - #include - -#include -#include +#include +#include +#include namespace at { namespace native { - namespace { -struct Handle { - miopenHandle_t handle; - Handle() : handle(NULL) { - MIOPEN_CHECK(miopenCreate(&handle)); - } - ~Handle() { - if (handle) { - miopenDestroy(handle); - } - } -}; +void createMIOpenHandle(miopenHandle_t *handle) { + MIOPEN_CHECK(miopenCreate(handle)); +} -std::mutex mutex; -std::unordered_map handles; +void destroyMIOpenHandle(miopenHandle_t handle) { +// this is because of something dumb in the ordering of +// destruction. Sometimes atexit, the cuda context (or something) +// would already be destroyed by the time this gets destroyed. It +// happens in fbcode setting. @colesbury and I decided to not destroy +// the handle as a workaround. +// - @soumith +// +// Further note: this is now disabled globally, because we are seeing +// the same issue as mentioned above in CUDA 11 CI. +// - @zasdfgbnm +// +// #ifdef NO_MIOPEN_DESTROY_HANDLE +// #else +// miopenDestroy(handle); +// #endif +} -} // namespace +using MIOpenPoolType = at::cuda::DeviceThreadHandlePool; +} // namespace -miopenHandle_t getMiopenHandle() -{ +miopenHandle_t getMiopenHandle() { int device; HIP_CHECK(hipGetDevice(&device)); - std::lock_guard guard(mutex); - return handles[device].handle; + // Thread local PoolWindows are lazily-initialized + // to avoid initialization issues that caused hangs on Windows. + // See: https://github.com/pytorch/pytorch/pull/22405 + // This thread local unique_ptrs will be destroyed when the thread terminates, + // releasing its reserved handles back to the pool. + static auto pool = std::make_shared(); + thread_local std::unique_ptr myPoolWindow( + pool->newPoolWindow()); + + auto handle = myPoolWindow->reserve(device); + MIOPEN_CHECK(miopenSetStream(handle, at::hip::getCurrentHIPStream())); + return handle; } }} // namespace at::native diff --git a/aten/src/ATen/miopen/Utils.h b/aten/src/ATen/miopen/Utils.h index 90ee4b7a14ee..5952e4f4c796 100644 --- a/aten/src/ATen/miopen/Utils.h +++ b/aten/src/ATen/miopen/Utils.h @@ -7,12 +7,6 @@ namespace at { namespace native { -inline void setMIOpenStreamToCurrent() { - // NB: Due to in-place HIPify, getCurrentCUDAStream actually means - // getCurrentHIPStream - MIOPEN_CHECK(miopenSetStream(getMiopenHandle(), at::hip::getCurrentHIPStream())); -} - // This function makes tensors which have zero stride contiguous, by // setting the strides to 1. inline Tensor contiguousIfZeroInStrides(const Tensor& t) { diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index 9802797874b9..9778aa035cb1 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -1,7 +1,6 @@ #include #include -#include -#include +#include namespace at { @@ -9,295 +8,66 @@ namespace native { namespace { - inline int start_index(int a, int b, int c) { - return (int)std::floor((float)(a * c) / b); - } - - inline int end_index(int a, int b, int c) { - return (int)std::ceil((float)((a + 1) * c) / b); - } - - template - static void adaptive_avg_pool2d_single_out_frame( - scalar_t *input_p, - scalar_t *output_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW, - int64_t istrideD, - int64_t istrideH, - int64_t istrideW) - { - at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { - for (auto d = start; d < end; d++) - { - /* loop over output */ - int64_t oh, ow; - for(oh = 0; oh < osizeH; oh++) - { - int istartH = start_index(oh, osizeH, isizeH); - int iendH = end_index(oh, osizeH, isizeH); - int kH = iendH - istartH; - - for(ow = 0; ow < osizeW; ow++) - { - int istartW = start_index(ow, osizeW, isizeW); - int iendW = end_index(ow, osizeW, isizeW); - int kW = iendW - istartW; - - /* local pointers */ - scalar_t *ip = input_p + d*istrideD + istartH*istrideH + istartW*istrideW; - scalar_t *op = output_p + d*osizeH*osizeW + oh*osizeW + ow; - - /* compute local average: */ - scalar_t sum = 0; - int ih, iw; - for(ih = 0; ih < kH; ih++) - { - for(iw = 0; iw < kW; iw++) - { - scalar_t val = *(ip + ih*istrideH + iw*istrideW); - sum += val; - } - } - - /* set output to local average */ - *op = sum / kW / kH; - } - } - } - }); - } - - template - void adaptive_avg_pool2d_out_frame( - scalar_t *input_p, - scalar_t *output_p, - int64_t sizeB, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW, - int64_t istrideB, - int64_t istrideD, - int64_t istrideH, - int64_t istrideW) - { - at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - adaptive_avg_pool2d_single_out_frame( - input_p + b * istrideB, - output_p + b * sizeD * osizeH * osizeW, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); - } - }); - } - void adaptive_avg_pool2d_out_cpu_template( at::Tensor& output, at::Tensor const& input, IntArrayRef output_size) { TORCH_CHECK(output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2"); - for (int64_t i = 0; i < input.ndimension(); i++) { + int64_t ndim = input.ndimension(); + for (int64_t i = 0; i < ndim; i++) { TORCH_CHECK(input.size(i) > 0, "adaptive_avg_pooling2d(): expected input to have non-empty spatial dimensions, " "but input has sizes ", input.sizes(), " with dimension ", i, " being " "empty"); } - TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), + TORCH_CHECK((ndim == 3 || ndim == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); + TORCH_CHECK(input.dtype() == output.dtype(), + "expected dtype ", input.dtype(), " for `output` but got dtype ", output.dtype()); - /* sizes */ - int64_t sizeD = input.size(-3); - int64_t isizeH = input.size(-2); - int64_t isizeW = input.size(-1); - /* strides */ - int64_t istrideD = input.stride(-3); - int64_t istrideH = input.stride(-2); - int64_t istrideW = input.stride(-1); - - auto osizeH = output_size[0]; - auto osizeW = output_size[1]; - - /* resize output */ - if (input.ndimension() == 3 || input.size(-4) == 1) - { - if (input.ndimension() == 3) { - output.resize_({sizeD, osizeH, osizeW}); - } else { - output.resize_({1, sizeD, osizeH, osizeW}); - } - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { - auto input_data = input.data_ptr(); - auto output_data = output.data_ptr(); - adaptive_avg_pool2d_single_out_frame( - input_data, - output_data, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); - } - ); - } - else - { - int64_t sizeB = input.size(-4); - output.resize_({sizeB, sizeD, osizeH, osizeW}); - int64_t istrideB = input.stride(-4); + int64_t channels = input.size(-3); + int64_t input_height = input.size(-2); + int64_t input_width = input.size(-1); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { - auto input_data = input.data_ptr(); - auto output_data = output.data_ptr(); - adaptive_avg_pool2d_out_frame( - input_data, - output_data, - sizeB, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideB, - istrideD, - istrideH, istrideW); - }); + if (ndim == 3) { + output.resize_({channels, output_height, output_width}); + } else { + int64_t nbatch = input.size(0); + output.resize_({nbatch, channels, output_height, output_width}, input.suggest_memory_format()); } - } - - template - static void adaptive_avg_pool2d_backward_single_out_frame( - scalar_t *gradInput_p, - scalar_t *gradOutput_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW) - { - at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { - for (auto d = start; d < end; d++) - { - scalar_t *gradInput_p_d = gradInput_p + d*isizeW*isizeH; - scalar_t *gradOutput_p_d = gradOutput_p + d*osizeW*osizeH; - - /* calculate average */ - int64_t oh, ow; - for(oh = 0; oh < osizeH; oh++) - { - int istartH = start_index(oh, osizeH, isizeH); - int iendH = end_index(oh, osizeH, isizeH); - int kH = iendH - istartH; - for(ow = 0; ow < osizeW; ow++) - { - - int istartW = start_index(ow, osizeW, isizeW); - int iendW = end_index(ow, osizeW, isizeW); - int kW = iendW - istartW; - - scalar_t grad_delta = gradOutput_p_d[oh*osizeW +ow] / kH / kW; - - int ih, iw; - for(ih = istartH; ih < iendH; ih++) - { - for(iw = istartW; iw < iendW; iw++) - { - /* update gradient */ - gradInput_p_d[ih*isizeW + iw] += grad_delta; - } - } - } - } - } - }); - } - - template - void adaptive_avg_pool2d_backward_out_frame( - scalar_t *gradInput_p, - scalar_t *gradOutput_p, - int64_t sizeB, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW) - { - at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - scalar_t *gradInput_p_d = gradInput_p + b * sizeD * isizeW * isizeH; - scalar_t *gradOutput_p_d = gradOutput_p + b * sizeD * osizeW * osizeH; - adaptive_avg_pool2d_backward_single_out_frame( - gradInput_p_d, - gradOutput_p_d, - sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - }); + adaptive_avg_pool2d_kernel(kCPU, output, input, output_size); } Tensor& adaptive_avg_pool2d_backward_out_cpu_template( - Tensor& gradInput, - const Tensor& gradOutput_, + Tensor& grad_input, + const Tensor& grad_output, const Tensor& input) { - /* sizes */ - int sizeD = input.size(-3); - int isizeH = input.size(-2); - int isizeW = input.size(-1); - int osizeH = gradOutput_.size(-2); - int osizeW = gradOutput_.size(-1); - - /* get contiguous gradOutput */ - auto gradOutput = gradOutput_.contiguous(); + int64_t ndim = grad_output.ndimension(); + for (int64_t i = 0; i < ndim; i++) { + TORCH_CHECK(grad_output.size(i) > 0, + "adaptive_avg_pooling2d_backward(): expected grad_output to have non-empty spatial dimensions, " + "but grad_output has sizes ", grad_output.sizes(), " with dimension ", i, " being " + "empty"); + } - /* backprop */ - if (input.ndimension() == 3 || input.size(-4) == 1) - { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { - /* get raw pointers */ - scalar_t *gradInput_data = gradInput.data_ptr(); - scalar_t *gradOutput_data = gradOutput.data_ptr(); + TORCH_CHECK((ndim == 3 || ndim == 4), + "non-empty 3D or 4D (batch mode) tensor expected for grad_output"); + TORCH_CHECK(input.dtype() == grad_output.dtype(), + "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype()); + TORCH_CHECK(input.dtype() == grad_input.dtype(), + "expected dtype ", input.dtype(), " for `grad_input` but got dtype ", grad_input.dtype()); - adaptive_avg_pool2d_backward_single_out_frame( - gradInput_data, gradOutput_data, - sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - ); - } - else - { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { - /* get raw pointers */ - scalar_t *gradInput_data = gradInput.data_ptr(); - scalar_t *gradOutput_data = gradOutput.data_ptr(); - int64_t sizeB = input.size(-4); + grad_input.resize_(input.sizes(), input.suggest_memory_format()); + grad_input.zero_(); - adaptive_avg_pool2d_backward_out_frame( - gradInput_data, gradOutput_data, - sizeB, sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - ); - } - return gradInput; + adaptive_avg_pool2d_backward_kernel(kCPU, grad_input, grad_output); + return grad_input; } } // namespace @@ -346,25 +116,27 @@ namespace { } Tensor& adaptive_avg_pool2d_backward_out_cpu( - Tensor& gradInput, - const Tensor& gradOutput, + Tensor& grad_input, + const Tensor& grad_output, const Tensor& input) { - gradInput.resize_as_(input); adaptive_avg_pool2d_backward_out_cpu_template( - gradInput, gradOutput, input); - return gradInput; + grad_input, grad_output, input); + return grad_input; } Tensor adaptive_avg_pool2d_backward_cpu( - const Tensor& gradOutput, + const Tensor& grad_output, const Tensor& input) { - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_input = at::empty({0}, input.options()); adaptive_avg_pool2d_backward_out_cpu_template( - gradInput, gradOutput, input); - return gradInput; + grad_input, grad_output, input); + return grad_input; } +DEFINE_DISPATCH(adaptive_avg_pool2d_kernel); +DEFINE_DISPATCH(adaptive_avg_pool2d_backward_kernel); + } // at::native } // at diff --git a/aten/src/ATen/native/AdaptivePooling.h b/aten/src/ATen/native/AdaptivePooling.h new file mode 100644 index 000000000000..29b2fd1c94c9 --- /dev/null +++ b/aten/src/ATen/native/AdaptivePooling.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel); +DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel); + +static inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (int64_t)std::floor((float)(a * c) / b); +} + +static inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return (int64_t)std::ceil((float)((a + 1) * c) / b); +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 9cc040b4dc8f..f8191c633d8b 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -78,6 +79,10 @@ extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex *a, extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info); extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info); +// geev +extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); +extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); + // gesdd extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex *a, int *lda, double *s, std::complex *u, int *ldu, std::complex *vt, int *ldvt, std::complex *work, int *lwork, double *rwork, int *iwork, int *info); @@ -305,6 +310,14 @@ template<> void lapackSyevd(char jobz, char uplo, int n, float *a, int ld ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); } +template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) { + dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); +} + +template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *wr, float *wi, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) { + sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); +} + template<> void lapackSvd, double>(char jobz, int m, int n, c10::complex *a, int lda, double *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, double *rwork, int *iwork, int *info) { zgesdd_(&jobz, &m, &n, reinterpret_cast*>(a), &lda, s, reinterpret_cast*>(u), &ldu, @@ -1155,6 +1168,46 @@ std::tuple symeig_out(Tensor& vals, Tensor& vecs, const Tensor return std::tuple(vals, vecs); } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +DEFINE_DISPATCH(eig_stub); + +std::tuple eig_out(Tensor& e, Tensor& v, const Tensor& self, bool eigenvectors) { + TORCH_CHECK(self.dim() == 2, "input should be 2 dimensional"); + TORCH_CHECK(self.size(0) == self.size(1), "input should be square"); + TORCH_CHECK(self.isfinite().all().item(), "input should not contain infs or NaNs"); + TORCH_CHECK(e.dtype() == self.dtype(), "Expected 'e' to have dtype ", self.dtype(), " but got ", e.dtype()); + if (eigenvectors) + TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype()); + int64_t n = self.size(-1); + + at::native::resize_output(e, {n, 2}); + if (eigenvectors) { + at::native::resize_output(v, self.sizes()); + } + + // optimization: if self is empty, we can immediately return the empty + // tensors, instead of getting empty tensors from eig_helper + if (self.numel() == 0) { + return std::tuple(e, v); + } + + Tensor vals_, vecs_; + std::tie(vals_, vecs_) = eig_stub(self.device().type(), self, eigenvectors); + e.copy_(vals_); + if (eigenvectors) { + v.copy_(vecs_); + } + return std::tuple(e, v); +} + +std::tuple eig(const Tensor& self, bool eigenvectors) { + Tensor e = at::empty({0}, self.options()); + Tensor v = at::empty({0}, self.options()); + at::eig_out(e, v, self, eigenvectors); + return std::tuple(e, v); +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h new file mode 100644 index 000000000000..95fc2c6097ce --- /dev/null +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +#include // for USE_LAPACK + + +namespace at { namespace native { + +#ifdef USE_LAPACK +// Define per-batch functions to be used in the implementation of batched +// linear algebra operations + +template +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); + +#endif + +using eig_fn = std::tuple (*)(const Tensor&, bool&); + +DECLARE_DISPATCH(eig_fn, eig_stub); + +}} // namespace at::native diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp new file mode 100644 index 000000000000..d251245c60c5 --- /dev/null +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -0,0 +1,77 @@ +#include +#include +#include +#include + +#include // for USE_LAPACK + +namespace at { namespace native { + +namespace { + +template +void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int64_t* info_ptr) { +#ifndef USE_LAPACK + TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ", + "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); +#else + char jobvr = eigenvectors ? 'V' : 'N'; + int64_t n = self.size(-1); + auto self_data = self.data_ptr(); + + auto vals_data = vals_.data_ptr(); + scalar_t* wr = vals_data; + scalar_t* wi = vals_data + n; + + scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr() : nullptr; + int ldvr = eigenvectors ? n : 1; + + if (n > 0) { + // call lapackEig once to get the optimal size for work data + scalar_t wkopt; + int info; + lapackEig('N', jobvr, n, self_data, n, wr, wi, + nullptr, 1, vecs_data, ldvr, &wkopt, -1, &info); + int lwork = static_cast(wkopt); + + // call again to do the actual work + Tensor work = at::empty({lwork}, self.dtype()); + lapackEig('N', jobvr, n, self_data, n, wr, wi, + nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, &info); + *info_ptr = info; + } +#endif +} + +std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvectors) { + int64_t n = self.size(-1); + // lapackEig function expects the input to be column major, or stride {1, n}, + // so we must set the stride manually since the default stride for tensors is + // row major, {n, 1} + Tensor self_ = at::empty_strided( + {n, n}, + {1, n}, + at::TensorOptions(self.dtype())); + self_.copy_(self); + + auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor vals_ = at::empty_strided({n, 2}, {1, n}, options); + Tensor vecs_ = eigenvectors + ? at::empty_strided({n, n}, {1, n}, options) + : Tensor(); + + int64_t info; + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cpu", [&]{ + apply_eig(self_, eigenvectors, vals_, vecs_, &info); + }); + singleCheckErrors(info, "eig_cpu"); + return std::tuple(vals_, vecs_); +} + +} // anonymous namespace + +REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl); +REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl); +REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl); + +}} // namespace at::native diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index d6cb17418365..e8751be55387 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -11,6 +11,18 @@ #include namespace at { +namespace meta { + +TORCH_META_FUNC2(add, Tensor) ( + const Tensor& self, const Tensor& other, Scalar alpha +) { + build_binary_op(maybe_get_output(), self, other); + native::alpha_check(dtype(), alpha); +} + +} // namespace meta + + namespace native { DEFINE_DISPATCH(add_stub); @@ -57,24 +69,11 @@ static Tensor wrapped_scalar_tensor(Scalar scalar) { return tensor; } -Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { - auto iter = TensorIterator::binary_op(result, self, other); - alpha_check(iter.dtype(), alpha); - add_stub(iter.device_type(), iter, alpha); - TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); - return result; -} - -Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) { - Tensor result; - auto iter = TensorIterator::binary_op(result, self, other); - alpha_check(iter.dtype(), alpha); - add_stub(iter.device_type(), iter, alpha); - return iter.output(); -} - -Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) { - return native::add_out(self, self, other, alpha); +TORCH_IMPL_FUNC(add_out) ( + Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha +) { + add_stub(device_type(), *this, alpha); + TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype()); } Tensor& add_relu_impl( @@ -148,7 +147,7 @@ Tensor& copysign_(Tensor& self, Scalar other) { return native::copysign_(self, wrapped_scalar_tensor(other)); } -Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { +Tensor& div_out(const Tensor& self, const Tensor& other, Tensor& result) { auto iter = TensorIterator::binary_float_op(result, self, other); div_stub(iter.device_type(), iter); return result; @@ -162,7 +161,7 @@ Tensor div(const Tensor& self, const Tensor& other) { } Tensor& div_(Tensor& self, const Tensor& other) { - return native::div_out(self, self, other); + return native::div_out(self, other, self); } // WARNING: There doesn't appear to be any testing for this function @@ -449,12 +448,15 @@ static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, Tensor tens return wrapped_scalar_tensor(scalar); } +// TODO: Make this structured to undo the perf regression from native:: removal +// in call here + Tensor add(const Tensor& self, Scalar other, Scalar alpha) { - return native::add(self, wrapped_scalar_tensor(other), alpha); + return at::add(self, wrapped_scalar_tensor(other), alpha); } Tensor& add_(Tensor& self, Scalar other, Scalar alpha) { - return native::add_(self, wrapped_scalar_tensor(other), alpha); + return self.add_(wrapped_scalar_tensor(other), alpha); } Tensor remainder(const Tensor& self, Scalar other) { @@ -1099,37 +1101,5 @@ Tensor& ldexp_(Tensor& self, const Tensor& other) { return at::ldexp_out(self, self, other); } -// TODO: Deduplicate this with the TensorIterator logic. This would -// also fix the TODOs below. -Tensor binary_op_meta(const Tensor& self, const Tensor& other) { - // TODO: Doesn't do type promotion correctly - // TODO: Doesn't do strides correctly - int64_t dim = std::max(self.dim(), other.dim()); - std::vector sizes(dim); - for (int64_t i = 0; i < dim; i++) { - int64_t j = -1 - i; - if (i >= self.dim() || self.size(j) == 1) { - sizes[dim + j] = other.size(j); - } else if (i >= other.dim() || self.size(i) == 1) { - sizes[dim + j] = self.size(j); - } else { - TORCH_CHECK( - self.size(j) == other.size(j), - "Expected self.size(", j, ") == other.size(", j, "), but got ", self.size(j), " != ", other.size(j) - ); - sizes[dim + j] = self.size(j); - } - } - return at::empty_meta(sizes, self.options()); -} - -Tensor binary_op_with_scalar_meta(const Tensor& self, const Tensor& other, Scalar x) { - return binary_op_meta(self, other); -} - -TORCH_LIBRARY_IMPL(aten, Meta, m) { - m.impl("add.Tensor", binary_op_with_scalar_meta); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index d76dd9d205e9..1fdb80590b5a 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -25,13 +25,15 @@ inline void sub_check(const Tensor& self, const Tensor& other) { "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead."); } +using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, Scalar alpha); + using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha); using binary_fn_beta = void(*)(TensorIterator&, double beta); using binary_fn = void(*)(TensorIterator&); using binary_clamp_fn_alpha = void(*)(TensorIterator&, Scalar alpha, Scalar min_val, Scalar max_val); -DECLARE_DISPATCH(binary_fn_alpha, add_stub); +DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub); DECLARE_DISPATCH(binary_fn_alpha, sub_stub); DECLARE_DISPATCH(binary_fn, mul_stub); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 6dbf1e5535ed..d55c4fca6027 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -177,14 +177,10 @@ auto ConvParams::needs_64bit_indexing_no_split(const at::Tensor& input, const at int64_t outsize = 1; if (transposed) { std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); - for (int64_t i = 1; i < o.size(); i++) { - outsize *= o[i]; - } + outsize = prod_intlist(o.begin() + 1, o.end()); } else { std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); - for (int64_t i = 1; i < o.size(); i++) { - outsize *= o[i]; - } + outsize = prod_intlist(o.begin() + 1, o.end()); } return outsize > int_max; } diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index d4c106477fe7..0c562f363731 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -11,12 +11,18 @@ namespace at { namespace native { static CPUCapability compute_cpu_capability() { auto envar = std::getenv("ATEN_CPU_CAPABILITY"); if (envar) { +#ifdef HAVE_VSX_CPU_DEFINITION + if (strcmp(envar, "vsx") == 0) { + return CPUCapability::VSX; + } +#else if (strcmp(envar, "avx2") == 0) { return CPUCapability::AVX2; } if (strcmp(envar, "avx") == 0) { return CPUCapability::AVX; } +#endif if (strcmp(envar, "default") == 0) { return CPUCapability::DEFAULT; } @@ -33,7 +39,11 @@ static CPUCapability compute_cpu_capability() { } } #endif +#ifdef HAVE_VSX_CPU_DEFINITION + return CPUCapability::VSX; +#else return CPUCapability::DEFAULT; +#endif } CPUCapability get_cpu_capability() { diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 63e2462489be..0368fa9741e9 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -47,8 +47,12 @@ namespace at { namespace native { enum class CPUCapability { DEFAULT = 0, +#ifdef HAVE_VSX_CPU_DEFINITION + VSX = 1, +#else AVX = 1, AVX2 = 2, +#endif NUM_OPTIONS }; @@ -101,6 +105,12 @@ struct CAFFE2_API DispatchStub { AT_ASSERTM(AVX, "DispatchStub: missing AVX kernel"); return AVX; } +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::VSX)) { + AT_ASSERTM(VSX, "DispatchStub: missing VSX kernel"); + return VSX; + } #endif AT_ASSERTM(DEFAULT, "DispatchStub: missing default kernel"); return DEFAULT; @@ -124,6 +134,9 @@ struct CAFFE2_API DispatchStub { #ifdef HAVE_AVX2_CPU_DEFINITION static FnPtr AVX2; #endif +#ifdef HAVE_VSX_CPU_DEFINITION + static FnPtr VSX; +#endif }; namespace { @@ -173,10 +186,17 @@ struct RegisterHIPDispatch { #define REGISTER_AVX2_DISPATCH(name, fn) #endif +#ifdef HAVE_VSX_CPU_DEFINITION +#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn) +#else +#define REGISTER_VSX_DISPATCH(name, fn) +#endif + #define REGISTER_NO_CPU_DISPATCH(name, fn_type) \ REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast(nullptr)) \ REGISTER_AVX_DISPATCH(name, static_cast(nullptr)) \ - REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) + REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) \ + REGISTER_VSX_DISPATCH(name, static_cast(nullptr)) #define REGISTER_CUDA_DISPATCH(name, fn) \ static RegisterCUDADispatch name ## __register(name, fn); diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index 5fbc1506bfaa..90006c74346d 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -188,6 +188,9 @@ FOREACH_UNARY_OP(sinh); FOREACH_UNARY_OP(round); FOREACH_UNARY_OP(lgamma); FOREACH_UNARY_OP(frac); +FOREACH_UNARY_OP(trunc); +FOREACH_UNARY_OP(reciprocal); +FOREACH_UNARY_OP(sigmoid); FOREACH_POINTWISE_OP_SCALAR(addcdiv); FOREACH_POINTWISE_OP_SCALAR(addcmul); @@ -201,7 +204,7 @@ std::vector foreach_tensor_##NAME##_slow(TensorList tensors1, TensorList \ std::vector result; \ result.reserve(tensors1.size()); \ - for (int i = 0; i < tensors1.size(); i++) { \ + for (size_t i = 0; i < tensors1.size(); i++) { \ result.emplace_back(at::NAME(tensors1[i], tensors2[i])); \ } \ \ diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index bac2f80e8a7c..c9e03aaa3b6b 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -136,334 +136,241 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra return result; } -// There are roughly three parts to compute einsum: -// 1. Parse equation to extract the labels for each input operand and output -// 2. Unsqueeze missing dimensions from input operands and permute to align them -// 3. Compute result by multiplying input operands and summing contraction -// dimensions We do the last part by reducing to bmm. -Tensor einsum(std::string equation, TensorList operands) { - TORCH_CHECK(!operands.empty(), "einsum() must provide at least one operand"); - checkDeviceType("einsum()", operands, operands[0].device().type()); - - // Code for encoding ellipsis ("...") with labels - constexpr int ELLIPSIS = '.'; - - // Find arrow (->) to split equation into lhs and rhs - const auto arrow_pos = equation.find("->"); - const auto lhs = equation.substr(0, arrow_pos); - - // Convert labels for input operands into an index in [0, 25] and store - // them in op_labels for each operand along with ELLIPSIS. - std::vector> op_labels(operands.size()); - bool found_ell = false; - std::string::size_type curr_op = 0; - for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) { - switch (lhs[i]) { - case ' ': - // Ignore spaces - break; - - case '.': - TORCH_CHECK( - // Only one ellipsis per operand can be given - !found_ell, - "einsum() found \'.\' for operand ", - curr_op, - " for which an ellipsis was already found"); - TORCH_CHECK( - // Ensure it's a valid ellipsis - i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.', - "einsum() found \'.\' for operand ", - curr_op, - " that is not part of any ellipsis"); - op_labels[curr_op].push_back(ELLIPSIS); - found_ell = true; - break; - - case ',': - // Move onto next operand - ++curr_op; - TORCH_CHECK( - curr_op < operands.size(), - "einsum() fewer operands were provided than specified in the equation"); - found_ell = false; - break; - - default: - // Parse label - TORCH_CHECK( - lhs[i] >= 'a' && lhs[i] <= 'z', - "einsum() operand subscript must be in range [a, z] but found ", - lhs[i], - " for operand ", - curr_op); - // Convert label to index in [0, 25] and store - op_labels[curr_op].push_back(lhs[i] - 'a'); - } +Tensor einsum(std::string eqn, TensorList tensors) { + constexpr size_t number_of_letters = 26; + std::string in_eqn; + size_t pos; + // The equation is given in terms of single lowercase letters ('a'..'z') and potentially an ellipsis. + // Internally, we represent it using indices from 0 to num_total_dimensions, with each letter + // mapped to an index and the ellipsis ('...') being mapped to a number of consequtive indices. + // The mapping of letters to internal indices is given in letter_mapping. A value of -1 means that + // the letter has not been assigned an index yet (because it has not been seen). + // The ellipsis is defined by first_ell_idx (the first index) and num_ell_idxes (the number of indices). + // A value of -1 for num_ell_idxes specifies that we have not seen an ellipsis yet. + // Note: The internal indices are NOT the dimensions used internally. There is a mapping to them below. + + std::array letter_mapping; // map letter to internal (numerical) label + letter_mapping.fill(-1); + int64_t num_ell_idxes = -1; + int64_t first_ell_idx = 0; + + // The internal representation of the left hand side fo the equation (with ellipsis expanded) is stored in input_op_idxes. + // For each operand, we have a vector mapping each dimension to an internal index. + // We also keep track of the number of occurrences for each letter (to infer a right hand side if not given) and + // of the last occurrence of each index. + std::vector> input_op_idxes; // the parsed operand indices + std::array num_letter_occurrences; // number of occurrence in the equation of this letter + num_letter_occurrences.fill(0); + std::vector last_idx_occurrence; // the last operator (left to right) using this index + + if ((pos = eqn.find("->")) != std::string::npos) { // check whether we have a right hand side. in_eq is the left hand side + in_eqn = eqn.substr(0, pos); + } else { + in_eqn = eqn; } - - TORCH_CHECK( - curr_op == operands.size() - 1, - "einsum() more operands were provided than specified in the equation"); - - // Labels must be within [a, z]. - constexpr int TOTAL_LABELS = 'z' - 'a' + 1; - std::vector label_count(TOTAL_LABELS, 0); - - // The maximum number of dimensions covered by any ellipsis, needed when - // unsqueezing missing dimensions from operands to permute and broadcast - int64_t ell_num_dim = 0; - - // Compute label frequency and number of dimensions covered by ellipsis - // We do this after parsing labels to make it more readable and simpler - // to compute the number of dimensions covered by ellipsis. - for (std::size_t i = 0; i < operands.size(); ++i) { - const Tensor operand = operands[i]; - std::vector labels = op_labels[i]; - int64_t nlabels = labels.size(); - int64_t ndims = operand.dim(); - bool has_ellipsis = false; - - for (int label : labels) { - if (label == ELLIPSIS) { - --nlabels; - has_ellipsis = true; - ell_num_dim = std::max(ell_num_dim, ndims - nlabels); - } else { - ++label_count[label]; + // remove spaces for einsum compatibility (#9929) + in_eqn.erase(std::remove_if(in_eqn.begin(), in_eqn.end(), isspace), in_eqn.end()); + + // next we parse in_eq (the left hand side) by iterating. It is a string of comma separated terms per index + int64_t operand = 0; + std::stringstream eqn_stream(in_eqn); + std::string term; + int64_t num_total_idxes = 0; + while (! eqn_stream.eof()) { + std::getline(eqn_stream, term, ','); // term = string with indices of current term + TORCH_CHECK((int64_t) tensors.size()>operand, "more operands in equation than tensors"); // we cannot have a longer equation than operands. We need to check here before we use the dimension + + int64_t ell_char_count = 0; // handling of ellipsis '...' is a bit tedious, we count the '.' + // if there is an ellipsis, the number of dimensions it represents must be total dim - letter dimensions + int64_t candidate_num_ell_idxes = tensors[operand].dim() - term.size() + 3; + int64_t dims_in_term = 0; // dimensions we have seen + std::vector current_op_idxes; // mapping of operand dimensions to indices for current term + for (auto &c : term) { // c = character with a single letter or '.' + if (c == '.') { + ell_char_count++; + TORCH_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in term ", operand, " of the equation"); + if (ell_char_count == 3) { // this completes the ellipsis + if (num_ell_idxes == -1) { // if we have not seen an ellipsis before, keep track of indices and size + first_ell_idx = num_total_idxes; + num_ell_idxes = candidate_num_ell_idxes; + num_total_idxes += num_ell_idxes; + } + else { // we have seen an ellipsis before, so we check compatibility + TORCH_CHECK(candidate_num_ell_idxes == num_ell_idxes, + "ellipsis must represent ", num_ell_idxes, " dimensions in all terms"); + } + for (int64_t i = 0; i < num_ell_idxes; ++i) { // map ellipsis dimensions in operand to indices + current_op_idxes.push_back(first_ell_idx + i); + last_idx_occurrence.push_back(operand); + } + dims_in_term += num_ell_idxes; // keep track of dimensions + } + } else { // a letter (hopefully) + TORCH_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis, operand ", operand); + TORCH_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices"); + int64_t letter_num = c-'a'; // letter_num = position in letter_mapping + if (letter_mapping[letter_num] == -1) { // new letter, add internal index and mapping + letter_mapping[letter_num] = num_total_idxes; + num_total_idxes++; + last_idx_occurrence.push_back(operand); + } else { // letter we have already seen + last_idx_occurrence[letter_mapping[letter_num]] = operand; + } + num_letter_occurrences[letter_num]++; + current_op_idxes.push_back(letter_mapping[letter_num]); + dims_in_term++; } } - - TORCH_CHECK( - has_ellipsis ? nlabels <= ndims : nlabels == ndims, - "einsum() the number of subscripts in the equation (", - nlabels, - has_ellipsis ? ") is more than the number of dimensions (" - : ") does not match the number of dimensions (", - ndims, - ") for operand ", - i, - has_ellipsis ? "" : " and no ellipsis was given"); + TORCH_CHECK(dims_in_term == tensors[operand].dim(), "dimension mismatch for operand ", operand, ": equation ", dims_in_term, " tensor ", tensors[operand].dim()); + input_op_idxes.push_back(std::move(current_op_idxes)); + operand++; } - - // Mapping of label to index in the permuted tensors (out_dims + sum_dims) - // This will be used for aligning the dimensions of all input operands - std::vector label_perm_index(TOTAL_LABELS, -1); - - // Current index in the permuted shape - int perm_index = 0; - - // Start index of ellipsis dimensions in the permuted shape - int64_t ell_index = 0; - found_ell = false; - - if (arrow_pos == std::string::npos) { - // Implicit output is ellipsis (...) + labels seen only once - perm_index = ell_num_dim; - found_ell = true; - for (int label = 0; label < TOTAL_LABELS; ++label) { - if (label_count[label] == 1) { - label_perm_index[label] = perm_index++; + // in the check below, we need ==, but > is captured above, so the error message can be specific that it is <. + TORCH_CHECK((int64_t) tensors.size()==operand, "more tensors than operands in equation"); + + // the following parses or infers output (right hand side) + // it also assigns the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors) + // for the output indices. -1 means that the index has not been assigned a dimension yet + std::vector idxes_to_preprocessed_dims(num_total_idxes, -1); // the position of the index in the tensor dimensions + int64_t num_output_dims = 0; + if (pos != std::string::npos) { // parse the user provided right hand side + int64_t ell_char_count = 0; + for (auto &c : eqn.substr(pos+2)) { + if (c == '.') { // '.' as part of ellipsis + ell_char_count++; + TORCH_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in right hand side of the equation"); + if (ell_char_count == 3) { // ellipsis complete + TORCH_CHECK(num_ell_idxes >= 0, "ellipsis '...' may only appear in right hand side if it does in left hand side"); + for (int64_t i = 0; i < num_ell_idxes; ++i) { + idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims; + num_output_dims++; + } + } + } else if (! isspace(c)) { // letter (hopefully) + TORCH_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis in the right hand side"); + TORCH_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices"); + int64_t letter_num = c-'a'; + TORCH_CHECK(idxes_to_preprocessed_dims[letter_mapping[letter_num]] == -1, "index ", c, " occurs twice in output"); + idxes_to_preprocessed_dims[letter_mapping[letter_num]] = num_output_dims; + num_output_dims++; } } - } else { - // Parse explicit output - const std::string rhs = equation.substr(arrow_pos + 2); - for (std::size_t i = 0; i < rhs.length(); ++i) { - switch (rhs[i]) { - case ' ': - // Ignore spaces - break; - - case '.': - TORCH_CHECK( - // There can only be one ellipsis in the output - !found_ell, - "einsum() found \'.\' for output but an ellipsis (...) was already found"); - TORCH_CHECK( - // Ensure ellipsis is correct - i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.', - "einsum() found \'.\' for output that is not part of any ellipsis (...)"); - ell_index = perm_index; - perm_index += ell_num_dim; - found_ell = true; - break; - - default: - TORCH_CHECK( - rhs[i] >= 'a' && rhs[i] <= 'z', - "einsum() subscripts must be in range [a, z] but found ", - rhs[i], - " for the output"); - TORCH_CHECK( - // Ensure label appeared at least once for some input operand and at - // most once for the output - label_count[rhs[i] - 'a'] > 0, - "einsum() output subscript ", - rhs[i], - label_count[rhs[i] - 'a'] == -1 - ? " appears more than once in the output" - : " does not appear in the equation for any input operand"); - label_perm_index[rhs[i] - 'a'] = perm_index++; - - // Set to -1 to mark that this label already appeared in the output - label_count[rhs[i] - 'a'] = -1; + } else { // create an inferred right hand side + // the ellipsis (if in the lhs) comes first + if (num_ell_idxes >= 0) { + for (int64_t i = 0; i < num_ell_idxes; ++i) { + idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims; + num_output_dims++; + } + } + // then the indices that occur exactly once in alphabetic order + for (size_t idx = 0; idx < number_of_letters; idx++) { + if (num_letter_occurrences[idx] == 1) { + idxes_to_preprocessed_dims[letter_mapping[idx]] = num_output_dims; + num_output_dims++; } } } - - // Save output size before adding sum dims - const int out_size = perm_index; - - // If ellipsis is not part of the output, add to contraction dimensions - if (ell_num_dim > 0 && !found_ell) { - ell_index = perm_index; - perm_index += ell_num_dim; - } - - // Add contraction labels (labels not present in output) - for (int label = 0; label < TOTAL_LABELS; ++label) { - if (label_count[label] > 0 && label_perm_index[label] == -1) { - label_perm_index[label] = perm_index++; + // now we assign the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors) + // for the non-output indices - those that are eventually summed over + int64_t position = num_output_dims; + for (int64_t i = 0; i < num_total_idxes; i++) { + if (idxes_to_preprocessed_dims[i]==-1) { + idxes_to_preprocessed_dims[i] = position; + position++; } } - // Here we unsqueeze missing dimensions to make all operands have the same - // number of dimensions. We take diagonals for repeated labels within the - // same operand. Finally we permute the operands to align dimensions as - // per the perm_out_index we computed above. - std::vector permuted_operands; - for (std::size_t i = 0; i < operands.size(); ++i) { - std::vector perm_shape(perm_index, -1); - std::vector label_dim(TOTAL_LABELS, -1); - const std::vector labels = op_labels[i]; - Tensor operand = operands[i]; - const auto sizes = operand.sizes(); - std::size_t j = 0; - - for (int label : labels) { - if (label == ELLIPSIS) { - // Add missing dimensions under ellipsis - int64_t num_dim_diff = - ell_num_dim - (operand.dim() - labels.size() + 1); - for (int64_t k = 0; k < num_dim_diff; ++k) { - operand = operand.unsqueeze(j); + // we now "homogenize the dimensions", i.e. + // - take diagonals for duplicated indices + // - permute the dimensions to match the order given by idxes_to_preprocessed_dims + // - unsqueeze to create all dimensions for each index in each tensor where they are missing + // we also check that sizes match + // after this, all operands will have compatible shapes (i.e. all dimensions are aligned are broadcastable) + std::vector preprocessed_operands; + std::vector size_of_dims(num_total_idxes, -1); // keep track of sizes for each index, -1 means we have not seen a size yet + for (int64_t op = 0; op < (int64_t) tensors.size(); op++) { + auto preprocessed_op = tensors[op]; + std::vector idx_to_dim(num_total_idxes, -1); // the dimension which the index refers to in the original tensor, -1 means it does not appear + std::vector& current_op_input_idxes = input_op_idxes[op]; + int64_t dim = 0; // there are two dimension indices: dim is after taking diagonals, i is in input + for (size_t i = 0; i < current_op_input_idxes.size(); i++) { + auto idx = current_op_input_idxes[i]; + auto dim_out = idxes_to_preprocessed_dims[idx]; + if (idx_to_dim[dim_out] == -1) { // first appearance + idx_to_dim[dim_out] = dim; + if (size_of_dims[idx] == -1) { // keep track of sizes + size_of_dims[idx] = preprocessed_op.size(dim); + } + else { + TORCH_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i); } - for (int64_t k = 0; k < ell_num_dim; ++k) { - perm_shape[ell_index + k] = j++; + dim++; + } else { // duplicate dimension in tensor --> take diagonal of idx_to_dim[dim_out] and dim and put the diagonal dimension to idx_to_dim[dim_out] + TORCH_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i); + preprocessed_op = preprocessed_op.diagonal(0, idx_to_dim[dim_out], dim); + // diagonal moves the diagonal dimension to the back + // now we permute the last dim back to idx_to_dim[dim_out] + std::vector perm(preprocessed_op.dim(), 0); + for (int64_t d = 0; d < preprocessed_op.dim(); d++) { + if (d == idx_to_dim[dim_out]) { + perm[d] = preprocessed_op.dim() - 1; + } else { + perm[d] = d - (d > idx_to_dim[dim_out]); + } } - } else if (label_dim[label] != -1) { - // Repeated label, take diagonal - int64_t dim = label_dim[label]; - TORCH_CHECK( - sizes[j] == sizes[dim], - "einsum() subscript ", - char(label + 'a'), - " is repeated for operand ", - i, - " but the sizes don't match, ", - sizes[j], - " != ", - sizes[dim]); - operand = operand.diagonal(0, j, dim).movedim(-1, dim); - } else { - // Lookup output index for label - label_dim[label] = j; - perm_shape[label_perm_index[label]] = j++; + preprocessed_op = preprocessed_op.permute(perm); } } - - // Add dimensions for missing labels - for (int64_t& index : perm_shape) { - if (index == -1) { - operand = operand.unsqueeze(-1); - index = j++; + // now we permute the dimensions in the right order + std::vector permutation; // permutation for this tensor + for (auto &d : idx_to_dim) { + if (d > -1) { + permutation.push_back(d); } } - - permuted_operands.push_back(operand.permute(perm_shape)); - } - - // Check if operands broadcast and keep track of last operand with - // dimension size != 1 for optimizing reductions - std::vector dim_last_op(perm_index, 0); - bool has_zero_size_dim = false; - for (int dim = 0; dim < perm_index; ++dim) { - int64_t broadcast_size = permuted_operands[0].size(dim); - for (std::size_t i = 1; i < permuted_operands.size(); ++i) { - int64_t dim_size = permuted_operands[i].size(dim); - if (broadcast_size != dim_size && broadcast_size != 1 && dim_size != 1) { - std::ostringstream msg; - msg << "einsum() operands do not broadcast with remapped shapes [original->remapped]:"; - for (std::size_t j = 0; j < operands.size(); ++j) { - msg << " " << operands[j].sizes() << "->" - << permuted_operands[j].sizes(); - } - TORCH_CHECK(false, msg.str()); - } - if (dim_size != 1) { - broadcast_size = dim_size; - dim_last_op[dim] = i; + preprocessed_op = preprocessed_op.permute(permutation); + // finally, we insert dimensions for idxes not in the operand + for (size_t dim = 0; dim < idx_to_dim.size(); dim++) { + if (idx_to_dim[dim] == -1) { + preprocessed_op = preprocessed_op.unsqueeze(dim); } } - has_zero_size_dim |= broadcast_size == 0; - } - - // Compute result - Tensor result = permuted_operands[0]; - // Fast path for when an operand has zero sized dim - if (has_zero_size_dim) { - std::vector out_shape(out_size); - for (int i = 0; i < out_size; ++i) { - out_shape[i] = permuted_operands[dim_last_op[i]].size(i); - } - return at::zeros(out_shape, result.options()); + preprocessed_operands.push_back(std::move(preprocessed_op)); } - // Sum out or squeeze dimensions that are size 1 for all later operands - int dim = out_size; - for (int i = dim; i < perm_index; ++i, ++dim) { - if (dim_last_op[i] == 0) { - if (result.size(dim) == 1) { - result = result.squeeze(dim--); - } else { - result = result.sum(dim--); - } + // now we reduce the indices from left to right + // numpy allows to optimize the path using various + // algorithms (see eigen_path in numpy docs) + // we start with the leftmost operator and reduce indices that + // appear only there + Tensor result = std::move(preprocessed_operands[0]); + for (int64_t idx = 0; idx < num_total_idxes; idx++) { + if ((last_idx_occurrence[idx] == 0) + && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) { + result = result.sum(idxes_to_preprocessed_dims[idx], true); } } - for (std::size_t i = 1; i < permuted_operands.size(); ++i) { - Tensor operand = permuted_operands[i]; + // now we process each tensor using sumproduct_pair + for (int64_t i = 1; i < (int64_t) preprocessed_operands.size(); i++) { std::vector sum_dims; - - // Sum out or squeeze dimensions that are size 1 for all later operands - dim = out_size; - for (int j = dim; j < perm_index; ++j, ++dim) { - if (dim_last_op[j] < i) { - operand = operand.squeeze(dim); - --dim; - } else if (dim_last_op[j] == i) { - if (result.size(dim) == 1) { - operand = operand.sum(dim); - result = result.squeeze(dim); - --dim; - } else { - sum_dims.push_back(dim); - } + for (int64_t idx = 0; idx < num_total_idxes; idx++) { + if ((last_idx_occurrence[idx] == i) + && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) { + sum_dims.push_back(idxes_to_preprocessed_dims[idx]); } } - - // Multiply tensors and sum out dimensions in sum_dims - if (sum_dims.empty()) { - result = result.mul(operand); - } else if (sum_dims.size() == result.sizes().size()) { - result = result.flatten().dot(operand.flatten()); - } else { - result = sumproduct_pair(result, operand, sum_dims, false); - } + result = at::native::sumproduct_pair(result, std::move(preprocessed_operands[i]), sum_dims, true); + } + // finally, we squeeze out all non-result dimensions + auto sizes = result.sizes().vec(); + for (int64_t dim = num_total_idxes-1; dim >= num_output_dims; dim--) { + sizes.erase(sizes.begin() + dim); } + result = result.view(sizes); return result; } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index bbc8d29dfab7..1c3b9ca60c1c 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1463,14 +1463,13 @@ Tensor matrix_power(const Tensor& a, int64_t n) { } Tensor frobenius_norm(const Tensor& self) { - TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors"); return at::norm(self); } Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { // NOTE: As frobenius_norm_out is currently implemented, it will always produce a // strided tensor result, even if the input is sparse. - auto options = self.options().layout(c10::Layout::Strided); + auto options = self.options().layout(c10::Layout::Strided).dtype(toValueType(self.scalar_type())); Tensor result = at::empty({0}, options); return at::native::frobenius_norm_out(result, self, dim, keepdim); } @@ -1480,7 +1479,6 @@ Tensor &frobenius_norm_out( const Tensor& self, IntArrayRef dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors"); TORCH_CHECK( dim.size() <= 2, "Expected at most 2 dimensions, but got ", @@ -1524,7 +1522,7 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) { } Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + Tensor result = at::empty({0}, self.options().dtype(toValueType(self.scalar_type()))); return at::native::nuclear_norm_out(result, self, dim, keepdim); } @@ -1679,7 +1677,7 @@ static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optio // when the input contains extreme values (like nan or +/-inf) or if the input // size is degenerate (like size(0), size(0, N), etc) case_was_overridden = true; - self_ = self.abs(); + self_ = self_.abs(); result_ = _norm_min_max(self_, ord, dim[0], keepdim); } else if ((self_.numel() == 0) && (ord < 0)) { // For negative orders with degenerate input sizes, at::norm's result does not @@ -1698,7 +1696,7 @@ static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optio } if (!case_was_overridden) { if (opt_dtype.has_value()) { - result_ = at::norm(self, opt_ord, dim, keepdim, opt_dtype.value()); + result_ = at::norm(self.to(opt_dtype.value()), opt_ord, dim, keepdim); } else { result_ = at::norm(self, opt_ord, dim, keepdim); } @@ -1749,14 +1747,14 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional // Numerical or None norms Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { - auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device()); Tensor result = at::empty({0}, options); return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); } // Frobenius and nuclear norms Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { - auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device()); Tensor result = at::empty({0}, options); return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } @@ -1781,7 +1779,8 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) { "linalg_cond does not support yet this case."); } auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2); - Tensor result = at::full(result_shape, INFINITY, self.options()); + TensorOptions options = self.options().dtype(toValueType(self.scalar_type())); + Tensor result = at::full(result_shape, INFINITY, options); return result; } @@ -1816,7 +1815,8 @@ Tensor _linalg_cond_helper(const Tensor& self, c10::variant // Return zero for each matrix in the batch Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) { auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2); - return at::zeros(result_shape, self.options().dtype(dtype)); + TensorOptions options = self.options().dtype(toValueType(self.scalar_type())); + return at::zeros(result_shape, options); } void _linalg_cond_check_ord(c10::variant ord_variant) { @@ -1849,8 +1849,7 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input if (self.numel() == 0) { auto real_dtype = toValueType(typeMetaToScalarType(self.dtype())); - auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type(); - return _linalg_cond_empty_matrix(self, expected_dtype); + return _linalg_cond_empty_matrix(self, real_dtype); } // If ord == None or ord == ±2 @@ -1883,10 +1882,9 @@ Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt // the result is always real-valued, for other cases it is complex-valued for the complex-valued input. ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; - auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type(); - TORCH_CHECK(result.scalar_type() == expected_dtype, - "result dtype ", result.scalar_type(), " does not match the expected dtype ", expected_dtype); + TORCH_CHECK(result.scalar_type() == real_dtype, + "result dtype ", result.scalar_type(), " does not match the expected dtype ", real_dtype); Tensor result_tmp = at::linalg_cond(self, opt_ord); at::native::resize_output(result, result_tmp.sizes()); @@ -1916,8 +1914,9 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { // TODO: implement _out variant avoiding copy and using already allocated storage directly Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) { - TORCH_CHECK(result.scalar_type() == self.scalar_type(), - "result dtype ", result.scalar_type(), " does not match the expected dtype ", self.scalar_type()); + ScalarType real_type = toValueType(self.scalar_type()); + TORCH_CHECK(result.scalar_type() == real_type, + "result dtype ", result.scalar_type(), " does not match the expected dtype ", real_type); Tensor result_tmp = at::linalg_cond(self, ord); at::native::resize_output(result, result_tmp.sizes()); diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 2ddcf5bd5c16..1ac4250a9d54 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -415,6 +415,7 @@ std::tuple _batch_norm_impl_index( bool use_cudnn = false; use_cudnn = (input.is_cuda() + && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16 && (input.scalar_type() != at::kHalf || weight.scalar_type() == at::kFloat) && weight.defined() && bias.defined() diff --git a/aten/src/ATen/native/PixelShuffle.cpp b/aten/src/ATen/native/PixelShuffle.cpp index 14c126f77bdf..e6301e682d77 100644 --- a/aten/src/ATen/native/PixelShuffle.cpp +++ b/aten/src/ATen/native/PixelShuffle.cpp @@ -4,31 +4,51 @@ #include #include +#include #include namespace at { namespace native { Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) { - AT_ASSERTM(self.dim() == 4, - "pixel_shuffle expects 4D input, but got input with sizes ",self.sizes()); - int64_t b = self.size(0); - int64_t c = self.size(1); - int64_t h = self.size(2); - int64_t w = self.size(3); + TORCH_CHECK(self.dim() >= 3, + "pixel_shuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), " dimension(s)"); + // Format: (B1, ..., Bn), C, H, W + int64_t c = self.size(-3); + int64_t h = self.size(-2); + int64_t w = self.size(-1); + const auto NUM_NON_BATCH_DIMS = 3; + const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS; + int64_t upscale_factor_squared = upscale_factor * upscale_factor; - AT_ASSERTM(c % upscale_factor_squared == 0, - "pixel_shuffle expects input channel to be divisible by square of " - "upscale_factor, but got input with sizes ", self.sizes(), - ", upscale_factor=", upscale_factor, - ", and self.size(1)=", c, " is not divisible by ", upscale_factor_squared); + TORCH_CHECK(c % upscale_factor_squared == 0, + "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " + "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared); int64_t oc = c / upscale_factor_squared; int64_t oh = h * upscale_factor; int64_t ow = w * upscale_factor; - auto input_reshaped = self.reshape({b, oc, upscale_factor, upscale_factor, h, w}); - return input_reshaped.permute({0 /* b */, 1 /* oc */, 4 /* h */, 2 /* 1st upscale_factor */, 5 /* w */, 3 /* 2nd upscale_factor */}) - .reshape({b, oc, oh, ow}); + // First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor). + // This allows shuffling to be done next by permuting dims. + std::vector expanded_shape(self.sizes().begin(), last_batch_dim); + expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w}); + const auto input_expanded = self.reshape(expanded_shape); + + // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims. + std::vector permutation(self.sizes().begin(), last_batch_dim); + // std::iota is used to maintain the batch dims within the permutation. + // Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6. + std::iota(permutation.begin(), permutation.end(), -expanded_shape.size()); + permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */, + -3 /* 2nd upscale_factor */}); + const auto input_permuted = input_expanded.permute(permutation); + + // Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh) + // and (w, upscale_factor) -> a single dim (ow). + std::vector final_shape(self.sizes().begin(), last_batch_dim); + final_shape.insert(final_shape.end(), {oc, oh, ow}); + return input_permuted.reshape(final_shape); } }} // namespace at::native diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 364c112572b5..e4b0a1cb19b7 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -660,22 +660,23 @@ Tensor& logsumexp_out(Tensor& result, const Tensor& self, DimnameList dims, bool static Tensor& norm_out(Tensor &result, const Tensor &self, optional opt_p, IntArrayRef dim, bool keepdim, optional opt_dtype) { - auto p = opt_p.value_or(2.0); - TORCH_CHECK(!(p.toDouble() == 2 && self.is_complex()), "norm with p=2 not supported for complex tensors"); + auto p = opt_p.value_or(2.0).to(); TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, "norm only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "norm only supports strided layout, got: ", self.layout()); - ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); + ScalarType in_dtype = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); TORCH_CHECK( - at::isFloatingType(scalarType) || at::isComplexType(scalarType), - "Can only calculate the mean of floating types. Got ", - toString(scalarType), + at::isFloatingType(in_dtype) || at::isComplexType(in_dtype), + "Can only calculate the norm of floating point and complex dtypes. Got ", + toString(in_dtype), " instead."); - ScalarType dtype = get_dtype(result, self, opt_dtype, true); - auto iter = make_reduction("norm", result, self, dim, keepdim, dtype); + ScalarType out_dtype = result.defined() ? result.scalar_type() : (opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())); + + auto iter = make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype); + if (iter.numel() == 0) { result.zero_(); } else { diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index 437a39bf2b92..4106a90c0729 100644 --- a/aten/src/ATen/native/SharedReduceOps.h +++ b/aten/src/ATen/native/SharedReduceOps.h @@ -2,6 +2,8 @@ // Please note that this file is // used across both CPU and GPU. +#include +#include #include #include #include @@ -157,11 +159,15 @@ struct MeanOps { } }; -template +// This accumulator template is used to calculate the minimum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct AbsMinOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return MIN(acc, acc_t(std::abs(data))); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MIN(acc, static_cast(std::abs(data))); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -177,17 +183,21 @@ struct AbsMinOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template +// This accumulator template is used to calculate the maximum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct AbsMaxOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return MAX(acc, acc_t(std::abs(data))); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MAX(acc, static_cast(std::abs(data))); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -203,18 +213,22 @@ struct AbsMaxOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template +// This accumulator template is used to calculate the norm of the absolute value +// of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormOps { acc_t norm_; - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + compat_pow(std::abs(data), norm_); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + compat_pow(static_cast(std::abs(data)), norm_); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -222,7 +236,7 @@ struct NormOps { } inline C10_DEVICE acc_t project(acc_t a) const { - return compat_pow(a, acc_t(1.0)/norm_); + return compat_pow(a, static_cast(1.0) / norm_); } static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { @@ -230,8 +244,8 @@ struct NormOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif @@ -239,10 +253,14 @@ struct NormOps { } }; -template +// This accumulator template is used to calculate the order zero norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormZeroOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + (data==acc_t(0) ? acc_t(0) : acc_t(1)); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + (data == static_cast(0) ? static_cast(0) : static_cast(1)); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -259,16 +277,20 @@ struct NormZeroOps { #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template +// This accumulator template is used to calculate the order one norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormOneOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + std::abs(data); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + static_cast(std::abs(data)); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -284,16 +306,40 @@ struct NormOneOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template + +template +struct AbsSwitch {}; + +template +inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch s) { + return static_cast(data); +} + +template +inline C10_DEVICE acc_t abs_if_complex(std::complex data, AbsSwitch s) { + return static_cast(std::abs(data)); +} + +template +inline C10_DEVICE acc_t abs_if_complex(c10::complex data, AbsSwitch s) { + return static_cast(std::abs(data)); +} + +// This accumulator template is used to calculate the order two norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormTwoOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + data * data; + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + acc_t data_ = abs_if_complex(data, AbsSwitch()); + return acc + data_ * data_; } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -309,8 +355,8 @@ struct NormTwoOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index dc680c382ab3..c8eb3cc99a01 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -468,12 +468,20 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop auto win_length = win_lengthOpt.value_or(n_fft); const bool return_complex = return_complexOpt.value_or( self.is_complex() || (window.defined() && window.is_complex())); - TORCH_CHECK( - return_complexOpt.has_value() || return_complex, - "stft requires the return_complex parameter be explicitly " - "specified for real inputs. Use return_complex=True to return " - "a complex-valued tensor, or return_complex=True to return " - "a real-valued tensor with an extra complex dimension."); + if (!return_complex) { + TORCH_CHECK(return_complexOpt.has_value(), + "stft requires the return_complex parameter be given for real inputs." + "You should pass return_complex=True to opt-in to complex dtype returns " + "(which will be required in a future pytorch release). " + ); + + TORCH_WARN_ONCE( + "stft with return_complex=False is deprecated. In a future pytorch " + "release, stft will return complex tensors for all inputs, and " + "return_complex=False will raise an error.\n" + "Note: you can still call torch.view_as_real on the complex output to " + "recover the old return format."); + } if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) { std::ostringstream ss; diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index ddc3ca8c2b34..f3147bdf78aa 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -293,6 +293,10 @@ Tensor index(const Tensor & self, TensorList indices) { Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); at::assert_no_internal_overlap(result); + at::assert_no_overlap(result, self); + for (auto& index: indices) { + at::assert_no_overlap(result, index); + } auto info = make_info(self, indices); auto iter = make_index_out_iterator(info, result); @@ -305,21 +309,24 @@ Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, } Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate, const bool unsafe) { - TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); - if (accumulate && self.device().type() == kCUDA) { - TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ", - value.device(), " for value tensor"); - index_put_accum_stub(self.device().type(), self, indices, value, unsafe); - return self; - } - + TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); if (at::has_internal_overlap(self) == MemOverlap::YES) { TORCH_WARN( "Use of index_put_ on expanded tensors is deprecated. " "Please clone() the tensor before performing this operation. " "This also applies to advanced indexing e.g. tensor[indices] = tensor"); } - at::assert_no_partial_overlap(self, value); + at::assert_no_overlap(self, value); + for (auto& index: indices) { + at::assert_no_overlap(self, index); + } + + if (accumulate && self.device().type() == kCUDA) { + TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ", + value.device(), " for value tensor"); + index_put_accum_stub(self.device().type(), self, indices, value, unsafe); + return self; + } auto info = make_info(self, indices); auto iter = make_index_put_iterator(info, value); @@ -339,6 +346,9 @@ Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Ten dim = maybe_wrap_dim(dim, self.dim()); TORCH_CHECK_INDEX(index.dim() < 2, "index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")"); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, source); int64_t numIndices = index.numel(); if (source.dim() == 0 && numIndices != 1) { @@ -394,8 +404,8 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T "index_add_(): Number of indices should be equal to self.size(dim)"); at::assert_no_internal_overlap(self); - at::assert_no_partial_overlap(self, index); - at::assert_no_partial_overlap(self, source); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, source); auto index_contig = index.contiguous(); @@ -461,6 +471,87 @@ Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const T return self.clone(at::MemoryFormat::Preserve).index_add_(dim, index, source); } +// Check that indices fall within dimension array size +// Avoid redispatch call to min/max +template +static void check_indexarray_range( + const IndexType* indices, + int64_t n, + IndexType indexing_axis_dim) { + for (auto i = 0; i < n; ++i) { + auto idx = indices[i]; + TORCH_CHECK( + 0 <= idx && idx < indexing_axis_dim, + "INDICES element is out of DATA bounds, id=", + idx, + " axis_dim=", + indexing_axis_dim); + } +} + +Tensor & index_select_out_cpu_dim1_( + Tensor & result_contig, const Tensor & self, const Tensor & index_contig) { + + auto self_contig = self.contiguous(); + const caffe2::TypeMeta dataType = self_contig.dtype(); + size_t item_bytesize = dataType.itemsize(); + + auto out = static_cast(result_contig.data_ptr()); + + auto src_base = static_cast(self_contig.data_ptr()); + + auto self_sizes = self_contig.sizes(); + auto outer_dims_product = c10::size_to_dim_(1, self_sizes); + auto block_size = c10::size_from_dim_(2, self_sizes); + auto block_bytesize = block_size * item_bytesize; + + auto src_indexing_axis_dim = self_sizes[1]; + auto src_batch_bytesize = self_sizes[1] * block_bytesize; + auto N = index_contig.numel(); + + auto gathered_batch_bytesize = N * block_bytesize; + + AT_DISPATCH_INDEX_TYPES( + index_contig.scalar_type(), "batch_index_select_compute", [&]() { + + const auto* idxs = index_contig.data_ptr(); + check_indexarray_range(idxs, N, src_indexing_axis_dim); + + // Special-case single-float copy for efficiency + if (self.scalar_type() == ScalarType::Float && block_size == 1) { + for (auto batch = 0; batch < outer_dims_product; ++batch) { + const float* src_floats = + (const float*)(src_base + batch * src_batch_bytesize); + float* dst_floats = (float*)(out + batch * gathered_batch_bytesize); + + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (idx < 0) { + idx = idx + src_indexing_axis_dim; + } + dst_floats[i] = src_floats[idx]; + } + } + } else { + // outer_dims_product specifies how many times we repeat inner dimensions, + // so we just iterate over it to cover all outer dimensions. + for (auto batch = 0; batch < outer_dims_product; ++batch) { + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (idx < 0) { + idx = idx + src_indexing_axis_dim; + } + + auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize; + auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize; + memcpy(dst, src, block_bytesize); + } + } + } + }); + return result_contig; +} + Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index) { dim = maybe_wrap_dim(dim, self.dim()); @@ -472,6 +563,8 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim TORCH_CHECK(dim == 0 || dim < self.dim(), "index_select(): Indexing dim ", dim, " is out of bounds of tensor"); at::assert_no_internal_overlap(result); + at::assert_no_overlap(result, self); + at::assert_no_overlap(result, index); auto result_size = self.sizes().vec(); if (self.dim() > 0) { @@ -486,6 +579,11 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim return result; } + if (dim == 1 && result.is_contiguous()) { + // fast pass + return index_select_out_cpu_dim1_(result, self, index_contig); + } + auto selfSlice = self.select(dim, 0); auto resultSlice = result.select(dim, 0); auto selfSlice_data = selfSlice.data_ptr(); @@ -608,6 +706,9 @@ Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & gather_out_cpu_cuda(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) { result.resize_(index.sizes()); + at::assert_no_internal_overlap(result); + at::assert_no_overlap(result, self); + at::assert_no_partial_overlap(result, index); gather_stub(result.device().type(), result, self, dim, index); return result; } @@ -627,6 +728,9 @@ Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, cons Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, source); + at::assert_no_overlap(self, index); scatter_stub(self.device().type(), self, dim, index, source); return self; } @@ -634,6 +738,8 @@ Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor Tensor & scatter_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar source) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); scatter_fill_stub(self.device().type(), self, dim, index, source); return self; } @@ -657,6 +763,8 @@ Tensor& scatter_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& in "scatter_(): Expected dtype int64 for index."); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "scatter_(): Expected floating or complex type for self."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); SCATTER_GATHER_OP op = get_operator_enum(reduce); scatter_scalar_reduce_stub(self.device().type(), self, dim, index, value, op); return self; @@ -668,6 +776,9 @@ Tensor & scatter_reduce_(Tensor & self, const int64_t dim, const Tensor & index, "scatter_(): Expected dtype int64 for index"); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "scatter_(): Expected floating or complex type for self."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, src); SCATTER_GATHER_OP op = get_operator_enum(reduce); scatter_reduce_stub(self.device().type(), self, dim, index, src, op); return self; @@ -684,6 +795,9 @@ Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar so Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, src); scatter_add_stub(self.device().type(), self, dim, index, src); return self; } @@ -780,8 +894,8 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, "masked_select(): self and result must have the same scalar type"); at::assert_no_internal_overlap(result); - at::assert_no_partial_overlap(result, self); - at::assert_no_partial_overlap(result, mask); + at::assert_no_overlap(result, self); + at::assert_no_overlap(result, mask); if (mask.dtype() == at::ScalarType::Byte) { TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \ @@ -895,6 +1009,9 @@ void take_out_cpu_template( auto index_continuous = index.contiguous(); bool is_contiguous = input.is_contiguous(); auto input_size = input.numel(); + at::assert_no_internal_overlap(output); + at::assert_no_partial_overlap(output, index); + at::assert_no_overlap(output, input); AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, input.scalar_type(), "take_cpu", [&] { auto output_data = output_contiguous.data_ptr(); diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 48dab43b2dc8..f395c6956da5 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -14,15 +13,11 @@ bool is_same_size(const Tensor& self, const Tensor& other) { } int64_t size(const Tensor& self, int64_t dim) { - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - dim = maybe_wrap_dim(dim, self.dim(), false); - return self.sizes()[dim]; + return self.size(dim); } int64_t stride(const Tensor& self, int64_t dim) { - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - dim = maybe_wrap_dim(dim, self.dim(), false); - return self.strides()[dim]; + return self.stride(dim); } int64_t size(const Tensor& self, Dimname dim) { diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 900f5ee72f7a..9c91821aed80 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -250,12 +250,12 @@ Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(r Tensor exp(const Tensor& self) { return unary_op_impl(self, at::exp_out); } Tensor& exp_(Tensor& self) { return unary_op_impl_(self, at::exp_out); } -Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, exp2_stub); } -Tensor exp2(const Tensor& self) { return unary_op_impl(self, at::exp2_out); } +Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp2_stub); } +Tensor exp2(const Tensor& self) { return unary_op_impl_float(self, exp2_stub); } Tensor& exp2_(Tensor& self) { return unary_op_impl_(self, at::exp2_out); } -Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, expm1_stub); } -Tensor expm1(const Tensor& self) { return unary_op_impl(self, at::expm1_out); } +Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, expm1_stub); } +Tensor expm1(const Tensor& self) { return unary_op_impl_float(self, expm1_stub); } Tensor& expm1_(Tensor& self) { return unary_op_impl_(self, at::expm1_out); } Tensor& erf_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erf_stub); } @@ -347,8 +347,8 @@ Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_floa Tensor sinh(const Tensor& self) { return unary_op_impl_float(self, sinh_stub); } Tensor& sinh_(Tensor& self) { return unary_op_impl_(self, at::sinh_out); } -Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cosh_stub); } -Tensor cosh(const Tensor& self) { return unary_op_impl(self, at::cosh_out); } +Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cosh_stub); } +Tensor cosh(const Tensor& self) { return unary_op_impl_float(self, cosh_stub); } Tensor& cosh_(Tensor& self) { return unary_op_impl_(self, at::cosh_out); } Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, acosh_stub); } diff --git a/aten/src/ATen/native/UpSampleNearest1d.cpp b/aten/src/ATen/native/UpSampleNearest1d.cpp index 45ac307ee4fc..b9dd52dffa5d 100644 --- a/aten/src/ATen/native/UpSampleNearest1d.cpp +++ b/aten/src/ATen/native/UpSampleNearest1d.cpp @@ -34,7 +34,9 @@ static std::array upsample_nearest1d_common_check(IntArrayRef input_ return {nbatch, channels, output_width}; } -TensorMeta upsample_nearest1d(const Tensor& input, IntArrayRef output_size, c10::optional scales) { +TORCH_META_FUNC(upsample_nearest1d) ( + const Tensor& input, IntArrayRef output_size, c10::optional scales +) { auto full_output_size = upsample_nearest1d_common_check(input.sizes(), output_size); // Allow for empty batch size but not other dimensions @@ -43,17 +45,19 @@ TensorMeta upsample_nearest1d(const Tensor& input, IntArrayRef output_size, c10: "Non-empty 3D data tensor expected but got a tensor with sizes ", input.sizes()); - return new_meta(input, full_output_size); + set_output(full_output_size, input.options()); } -TensorMeta upsample_nearest1d_backward(const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, c10::optional scales) { +TORCH_META_FUNC(upsample_nearest1d_backward) ( + const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, c10::optional scales +) { auto full_output_size = upsample_nearest1d_common_check(input_size, output_size); check_dim_size(grad_output, 3, 0, full_output_size[0]); check_dim_size(grad_output, 3, 1, full_output_size[1]); check_dim_size(grad_output, 3, 2, full_output_size[2]); - return new_meta(grad_output, input_size); + set_output(input_size, grad_output.options()); } } // namespace meta @@ -61,16 +65,15 @@ TensorMeta upsample_nearest1d_backward(const Tensor& grad_output, IntArrayRef ou namespace native { -Tensor& upsample_nearest1d_out_cpu( +TORCH_IMPL_FUNC(upsample_nearest1d_out_cpu) ( Tensor& output, const Tensor& input, IntArrayRef output_size, c10::optional scales) { upsample_nearest1d_kernel(kCPU, output, input, scales); - return output; } -Tensor& upsample_nearest1d_backward_out_cpu( +TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_cpu) ( Tensor& grad_input, const Tensor& grad_output, IntArrayRef output_size, @@ -78,7 +81,6 @@ Tensor& upsample_nearest1d_backward_out_cpu( c10::optional scales) { grad_input.zero_(); upsample_nearest1d_backward_kernel(kCPU, grad_input, grad_output, scales); - return grad_input; } using at::native::upsample::compute_output_size; diff --git a/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp b/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp new file mode 100644 index 000000000000..b5ed77f6e400 --- /dev/null +++ b/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp @@ -0,0 +1,311 @@ +#include + +#include +#include +#include +#include +#include + +namespace at { namespace native { + +namespace { + +template +void cpu_adaptive_avg_pool( + Tensor& output_, + const Tensor& input_, + IntArrayRef output_size) { + auto input = input_.contiguous(); + auto output = output_.contiguous(); + + auto input_data = input.data_ptr(); + auto output_data = output.data_ptr(); + + int64_t ndim = input.ndimension(); + // treat batch size and channels as one dimension + int64_t channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1); + int64_t input_height = input.size(-2); + int64_t input_width = input.size(-1); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + // parallel on dim of N, C + at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) { + for (int64_t c = begin; c < end; c++) { + scalar_t* input_ptr = input_data + c * input_height * input_width; + scalar_t* output_ptr = output_data + c * output_height * output_width; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + // compute local average + scalar_t sum = 0; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + sum += input_ptr[ih * input_width + iw]; + } + } + output_ptr[oh * output_width + ow] = sum / kh / kw; + } + } + } + }); + + if (!output_.is_contiguous()) { + output_.copy_(output); + } +} + +template +void cpu_adaptive_avg_pool_channels_last( + Tensor& output_, + const Tensor& input_, + IntArrayRef output_size) { + auto memory_format = at::MemoryFormat::ChannelsLast; + auto input = input_.contiguous(memory_format); + auto output = output_.contiguous(memory_format); + + auto input_data = input.data_ptr(); + auto output_data = output.data_ptr(); + + int64_t nbatch = input.size(0); + int64_t channels = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + using Vec = vec256::Vec256; + // parallel on dim N, H, W + at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) { + int64_t n = 0; + int64_t oh = 0; + int64_t ow = 0; + data_index_init(begin, n, nbatch, oh, output_height, ow, output_width); + + for (int64_t i = begin; i < end; i++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t* out = output_data + i * channels; + int64_t size = channels; + + // Note: For oridinary usage scenario, each out lane should + // fit in L1 cache; otherwise consider block dim C. + // Pass I: zero the out lane + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec out_vec = Vec(scalar_t(0)); + out_vec.store(out + d1); + } + for (; d1 < size; d1++) { + out[d1] = scalar_t(0); + } + // Pass II: compute local sum + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + scalar_t* in = input_data + n * input_height * input_width * channels + + ih * input_width * channels + iw * channels; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2); + out_vec.store(out + d2); + } + for (; d2 < size; d2++) { + out[d2] += in[d2]; + } + } + } + // Pass III: compute local average + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kh * kw)); + out_vec.store(out + d3); + } + for (; d3 < size; d3++) { + out[d3] = out[d3] / kh / kw; + } + + // move on to next output index + data_index_step(n, nbatch, oh, output_height, ow, output_width); + } + }); + + if (!output_.is_contiguous(memory_format)) { + output_.copy_(output); + } +} + +template +void cpu_adaptive_avg_pool_backward( + Tensor& grad_input_, + const Tensor& grad_output_) { + auto grad_output = grad_output_.contiguous(); + auto grad_input = grad_input_.contiguous(); + + auto grad_output_data = grad_output.data_ptr(); + auto grad_input_data = grad_input.data_ptr(); + + int64_t ndim = grad_output.ndimension(); + // treat batch size and channels as one dimension + int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1); + int64_t input_height = grad_input.size(-2); + int64_t input_width = grad_input.size(-1); + int64_t output_height = grad_output.size(-2); + int64_t output_width = grad_output.size(-1); + + // parallel on dim of N, C + at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) { + for (int64_t c = begin; c < end; c++) { + scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width; + scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t grad_delta = grad_output_ptr[oh * output_width + ow] / kh / kw; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + grad_input_ptr[ih * input_width + iw] += grad_delta; + } + } + } + } + } + }); + + if (!grad_input_.is_contiguous()) { + grad_input_.copy_(grad_input); + } +} + +template +void cpu_adaptive_avg_pool_backward_channels_last( + Tensor& grad_input_, + const Tensor& grad_output_) { + auto memory_format = at::MemoryFormat::ChannelsLast; + auto grad_input = grad_input_.contiguous(memory_format); + auto grad_output = grad_output_.contiguous(memory_format); + + auto grad_input_data = grad_input.data_ptr(); + auto grad_output_data = grad_output.data_ptr(); + + int64_t nbatch = grad_input.size(0); + int64_t channels = grad_input.size(1); + int64_t input_height = grad_input.size(2); + int64_t input_width = grad_input.size(3); + int64_t output_height = grad_output.size(2); + int64_t output_width = grad_output.size(3); + + using Vec = vec256::Vec256; + // parallel on dim N + at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) { + for (int64_t n = begin; n < end; n++) { + scalar_t* grad_input_ptr = grad_input_data + n * input_height * input_width * channels; + scalar_t* grad_output_ptr = grad_output_data + n * output_height * output_width * channels; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t* gout = grad_output_ptr + oh * output_width * channels + ow * channels; + int64_t size = channels; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + scalar_t* gin = grad_input_ptr + ih * input_width * channels + iw * channels; + + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kh * kw)); + gin_vec.store(gin + d); + } + for (; d < size; d++) { + gin[d] += gout[d] / kw / kw; + } + } + } + } + } + } + }); + + if (!grad_input_.is_contiguous(memory_format)) { + grad_input_.copy_(grad_input); + } +} + +void adaptive_avg_pool2d_kernel_impl( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + switch (input.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d", [&] { + cpu_adaptive_avg_pool(output, input, output_size); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{ + cpu_adaptive_avg_pool_channels_last(output, input, output_size); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +void adapative_avg_pool2d_backward_kernel_impl( + Tensor& grad_input, + const Tensor& grad_output) { + switch (grad_output.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "adaptive_avg_pool2d_backward", [&] { + cpu_adaptive_avg_pool_backward(grad_input, grad_output); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "adaptive_avg_pool2d_backward_channels_last", [&]{ + cpu_adaptive_avg_pool_backward_channels_last(grad_input, grad_output); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +} // anonymous namespace + +REGISTER_DISPATCH(adaptive_avg_pool2d_kernel, &adaptive_avg_pool2d_kernel_impl); +REGISTER_DISPATCH(adaptive_avg_pool2d_backward_kernel, &adapative_avg_pool2d_backward_kernel_impl); + +}} // at::native diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 36c01b2af49e..1792acffe57b 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -21,7 +21,7 @@ using namespace vec256; // Note: Undefined behavior when performing addition is intentionally // ignored. -void add_kernel(TensorIterator& iter, Scalar alpha_scalar) { +void add_kernel(TensorIteratorBase& iter, Scalar alpha_scalar) { if (iter.dtype() == ScalarType::Bool) { using scalar_t = bool; auto alpha = alpha_scalar.to(); diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 6ed9c798be23..10437f51d4b4 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -174,61 +174,75 @@ static void norm_kernel_tensor_iterator_impl( if (p.isIntegral(false)) { val = p.to(); } else if (p.isFloatingPoint()) { - val = p.to(); + val = p.to(); } else { AT_ERROR("norm_kernel_tensor_iterator_impl expects norm to be integer or float"); } - + // In the dispatch code blocks below, reduction kernels accumulate results as + // the type `acc_t`. When `scalar_t` is complex, `acc_t` is the downgraded + // real number type. Otherwise, `acc_t` and `scalar_t` are the same type. if (val == 0) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormZeroOps(), - scalar_t(0) + NormZeroOps(), + acc_t(0) ); }); } else if (val == 1) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormOneOps(), - scalar_t(0) + NormOneOps(), + acc_t(0) ); }); } else if (val == 2) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormTwoOps(), - scalar_t(0) + NormTwoOps(), + acc_t(0) ); }); } else if (val == INFINITY) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - AbsMaxOps(), - scalar_t(std::numeric_limits::min()) + AbsMaxOps(), + std::numeric_limits::min() ); }); } else if (val == -INFINITY) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - AbsMinOps(), - scalar_t(std::numeric_limits::max()) + AbsMinOps(), + std::numeric_limits::max() ); }); } else { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormOps { scalar_t(val) }, - scalar_t(0) + NormOps { acc_t(val) }, + acc_t(0) ); }); } + + // For complex outputs, the above kernels do not touch the imaginary values, + // so we must zero them out + if (isComplexType(iter.output().scalar_type())) { + at::imag(iter.output()).zero_(); + } } static void and_kernel_impl(TensorIterator& iter) { diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp index 7d13de185509..1d69af7c5622 100644 --- a/aten/src/ATen/native/cpu/SortingKernel.cpp +++ b/aten/src/ATen/native/cpu/SortingKernel.cpp @@ -47,10 +47,10 @@ void _dim_apply( auto values_dim_stride = values.stride(dim); auto indices_dim_stride = indices.stride(dim); auto dim_size = values.size(dim); - + AT_DISPATCH_ALL_TYPES_AND2( ScalarType::Bool, ScalarType::Half, iter.dtype(), - method_name, [&] { + "sorting_kernel_method_name", [&] { auto loop = [&](char** data, const int64_t* strides, int64_t n) { auto* values_data_bytes = data[0]; auto* indices_data_bytes = data[1]; @@ -68,7 +68,7 @@ void _dim_apply( indices_data_bytes += strides[1]; } }; - + iter.for_each(loop); } ); @@ -114,7 +114,7 @@ static void sort_kernel( auto composite_accessor = CompositeRandomAccessorCPU< decltype(values_accessor), decltype(indices_accessor) >(values_accessor, indices_accessor); - + if (descending) { std::sort(composite_accessor, composite_accessor + dim_size, KeyValueCompDesc()); diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index b9653c7b25bf..b407eac4d280 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -183,7 +183,8 @@ static void _aminmax_kernel_impl( } static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, iter.dtype(), "where_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, + iter.dtype(), "where_cpu", [&] { if (condition_type == at::ScalarType::Byte) { cpu_kernel( iter, diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index bfb136776333..f7c4f9c34613 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -277,7 +277,7 @@ static void sign_kernel(TensorIterator& iter){ [=](scalar_t a) -> scalar_t { return (0 < a) - (a < 0); }, [=](Vec256 self_vec){ - // Comparision operators returns bitmask. + // Comparison operators returns bitmask. auto left = Vec256::blendv(zero_vec, one_vec, zero_vec < self_vec); auto right = Vec256::blendv(zero_vec, one_vec, self_vec < zero_vec); diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index aa6d57cdd2df..61e7877761d8 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -4,36 +4,12 @@ #include #include #include +#include namespace at { namespace native { namespace { -template -inline T data_index_init(T offset) { - return offset; -} - -template -inline T data_index_init(T offset, T &x, const T &X, Args &&... args) { - offset = data_index_init(offset, std::forward(args)...); - x = offset % X; - return offset / X; -} - -inline bool data_index_step() { - return true; -} - -template -inline bool data_index_step(T &x, const T &X, Args &&... args) { - if (data_index_step(std::forward(args)...)) { - x = ((x + 1) == X) ? 0 : (x + 1); - return x == 0; - } - return false; -} - static inline int64_t nearest_idx( int64_t output_index, int64_t input_size, diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h new file mode 100644 index 000000000000..32d1de5adb51 --- /dev/null +++ b/aten/src/ATen/native/cpu/utils.h @@ -0,0 +1,30 @@ +#pragma once + +namespace at { namespace native { namespace { + +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T &x, const T &X, Args &&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T &x, const T &X, Args &&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +}}} // namespace at::native:: diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu index 4214b4dace19..388b04dba76a 100644 --- a/aten/src/ATen/native/cuda/AveragePool3d.cu +++ b/aten/src/ATen/native/cuda/AveragePool3d.cu @@ -317,16 +317,17 @@ __global__ void avg_pool3d_cuda_update_grad_input( } } -#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ +#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ avg_pool3d_cuda_update_output \ <<>>( \ - work_input.packed_accessor64(), \ - work_output.packed_accessor64(), \ + work_input.packed_accessor64(), \ + work_output.packed_accessor64(), \ kT, kH, \ dT, dH, dW, \ padT, padH, padW, \ count_include_pad, \ offsetZ, divisor); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ break void avg_pool3d_out_cuda_template( @@ -443,11 +444,10 @@ void avg_pool3d_out_cuda_template( padT, padH, padW, count_include_pad, offsetZ, divisor); - break; + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; } - AT_CUDA_CHECK(cudaGetLastError()); - totalZ -= 65535; offsetZ += 65535; } @@ -581,8 +581,7 @@ void avg_pool3d_backward_out_cuda_template( kT, kH, kW, 1.0f/divide_factor, offsetZ); - - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; @@ -614,6 +613,7 @@ void avg_pool3d_backward_out_cuda_template( padT, padH, padW, count_include_pad, offsetZ, divisor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { avg_pool3d_cuda_update_grad_input @@ -625,10 +625,9 @@ void avg_pool3d_backward_out_cuda_template( padT, padH, padW, count_include_pad, offsetZ, divisor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); - totalZ -= 65535; offsetZ += 65535; } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 5b16adaa2e5f..f49ddd288eb7 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -1941,7 +1942,8 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc * 2. return CPU tensors (because this is what magmaEig returns), which will be copied to GPU memory * by the caller */ -static std::tuple eig_cuda_helper(const Tensor& self, int64_t n, bool eigenvectors) { +std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvectors) { + int64_t n = self.size(-1); // copy self to pinned CPU memory auto self_working_copy = at::empty_strided( {n, n}, // square matrix @@ -1965,48 +1967,7 @@ static std::tuple eig_cuda_helper(const Tensor& self, int64_t n, return std::tuple(out_eigvals, out_eigvecs); } -std::tuple eig_cuda_out(Tensor& e, Tensor& v, const Tensor& self, bool eigenvectors) { - TORCH_CHECK(self.dim() == 2, "Expected a two-dimensional input but got ", self.dim(), " dimensions"); - TORCH_CHECK(e.dtype() == self.dtype(), "Expected 'e' to have dtype ", self.dtype(), " but got ", e.dtype()); - if (eigenvectors) - TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype()); - squareCheckInputs(self); - int64_t n = self.size(-1); - - at::native::resize_output(e, {n, 2}); - if (eigenvectors) { - at::native::resize_output(v, self.sizes()); - } - - // optimization: if self is empty, we can immediately return the empty - // GPU tensors, instead of getting empty CPU tensors from eig_cuda_helper - // and copying them to GPU - if (self.numel() == 0) { - return std::tuple(e, v); - } - - Tensor cpu_vals, cpu_vecs; - std::tie(cpu_vals, cpu_vecs) = eig_cuda_helper(self, n, eigenvectors); - e.copy_(cpu_vals); - if (eigenvectors) { - v.copy_(cpu_vecs); - } - return std::tuple(e, v); -} - -std::tuple eig_cuda(const Tensor& self, bool eigenvectors) { - TORCH_CHECK(self.dim() == 2, "Expected a two-dimensional input but got ", self.dim(), " dimensions"); - squareCheckInputs(self); - int64_t n = self.size(-1); - - Tensor e, v; - e = at::empty({n, 2}, self.options()); - if (eigenvectors) { - v = at::empty({n, n}, self.options()); - } - eig_cuda_out(e, v, self, eigenvectors); - return std::tuple(e, v); -} +REGISTER_DISPATCH(eig_stub, &eig_kernel_impl); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu index 864fb0a848df..bbc85f7997e4 100644 --- a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu @@ -18,7 +18,7 @@ struct AddFunctor { scalar_t alpha; }; -void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) { +void add_kernel_cuda(TensorIteratorBase& iter, Scalar alpha_scalar) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() { AddFunctor f(alpha_scalar.to()); gpu_kernel_with_scalars(iter, f); diff --git a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu index ed7e2190f75e..a385aa721522 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu @@ -16,10 +16,8 @@ namespace native { void sigmoid_backward_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "sigmoid_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "sigmoid_backward_cuda", [&] { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a * (scalar_t(1.) - b) * b; - }); + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t(1.) - b) * b; }); }); } @@ -31,31 +29,29 @@ void logit_backward_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) { iter.dtype(), "logit_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "logit_cuda", [&] { - using T_ACC = acc_type; - const T_ACC eps = eps_scalar.to(); - if (eps < T_ACC(0)) { - gpu_kernel( - iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - const T_ACC dy_acc = static_cast(dy); - const T_ACC x_acc = static_cast(x); - return (x_acc < T_ACC(0) || x_acc > T_ACC(1)) - ? std::numeric_limits::quiet_NaN() - : dy_acc / (x_acc * (T_ACC(1) - x_acc)); - }); - } else { - const T_ACC lo = eps; - const T_ACC hi = T_ACC(1) - eps; - gpu_kernel( - iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - const T_ACC dy_acc = static_cast(dy); - const T_ACC x_acc = static_cast(x); - return (x_acc < lo || x_acc > hi) - ? T_ACC(0) - : dy_acc / (x_acc * (T_ACC(1) - x_acc)); - }); - } - }); + using T_ACC = acc_type; + const T_ACC eps = eps_scalar.to(); + if (eps < T_ACC(0)) { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + const T_ACC dy_acc = static_cast(dy); + const T_ACC x_acc = static_cast(x); + return (x_acc < T_ACC(0) || x_acc > T_ACC(1)) + ? std::numeric_limits::quiet_NaN() + : dy_acc / (x_acc * (T_ACC(1) - x_acc)); + }); + } else { + const T_ACC lo = eps; + const T_ACC hi = T_ACC(1) - eps; + gpu_kernel( + iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + const T_ACC dy_acc = static_cast(dy); + const T_ACC x_acc = static_cast(x); + return (x_acc < lo || x_acc > hi) + ? T_ACC(0) + : dy_acc / (x_acc * (T_ACC(1) - x_acc)); + }); + } }); } @@ -68,10 +64,8 @@ void tanh_backward_kernel_cuda(TensorIterator& iter) { }); } else { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "tanh_backward_cuda", [&] { - gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a * (scalar_t{1.} - b * b); - }); + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t{1.} - b * b); }); }); } diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index 91401e994ebd..d11a5bb074c5 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -101,9 +101,11 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t switch (vec_size) { case 4: vectorized_elementwise_kernel<4, func_t, array_t><<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; case 2: vectorized_elementwise_kernel<2, func_t, array_t><<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; case 1: { auto input_calc = TrivialOffsetCalculator(); @@ -111,12 +113,12 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t auto loader = memory::LoadWithoutCast(); auto storer = memory::StoreWithoutCast(); unrolled_elementwise_kernel<<>>(N, f, data, input_calc, output_calc, loader, storer); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } default: TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); } - AT_CUDA_CHECK(cudaGetLastError()); } template @@ -127,7 +129,7 @@ static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t da int64_t grid = (N + block_work_size - 1) / block_work_size; auto stream = at::cuda::getCurrentCUDAStream(); unrolled_elementwise_kernel<<>>(N, f, data, ic, oc, l, s); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h index 8cfc6c10f1ba..1b4f228bf229 100644 --- a/aten/src/ATen/native/cuda/DistributionTemplates.h +++ b/aten/src/ATen/native/cuda/DistributionTemplates.h @@ -155,6 +155,7 @@ void distribution_nullary_kernel(at::TensorIterator& iter, *out = transform_func(rand); } ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto offset_calc = make_offset_calculator<1>(iter); distribution_elementwise_grid_stride_kernel<<>>( @@ -167,8 +168,8 @@ void distribution_nullary_kernel(at::TensorIterator& iter, *out = transform_func(rand); } ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); } // Binary kernel @@ -260,10 +261,12 @@ void distribution_binary_kernel(TensorIterator &iter, PhiloxCudaState philox_arg distribution_binary_elementwise_kernel<<>>( numel, f, philox_args, output_data, input_data_1, input_data_2, TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { distribution_binary_elementwise_kernel<<>>( numel, f, philox_args, output_data, input_data_1, input_data_2, make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 5bed5532baee..651261cf6408 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -245,8 +245,8 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad, max_indices.data_ptr(), grad.data_ptr(), grad_weight.data_ptr(), stride, numBags); C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); }); + }); return grad_weight; } diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index 2cd01d80bdca..88b952fe1d95 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -417,9 +417,9 @@ std::vector foreach_tensor_frac_cuda(TensorList tensors) { using opmath_t = get_opmath_t::opmath_t; multi_tensor_apply<2>(tensor_lists, UnaryOpFunctor(), + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), Trunc()); }); return tensor_lists[1]; @@ -439,10 +439,178 @@ void foreach_tensor_frac_cuda_(TensorList tensors) { using opmath_t = get_opmath_t::opmath_t; multi_tensor_apply<1>(tensor_lists, UnaryOpFunctor(), + /* depth */ 1, + /* r_args_depth */ 1, + /* res_arg_index */ 0>(), Trunc()); }); } + +template +struct Sigmoid { + T one = T(1); + __device__ T operator()(T t) const { return (one / (one + std::exp(-t))); } +}; + +std::vector foreach_tensor_sigmoid_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_sigmoid_slow(tensors); + } + + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Sigmoid()); + }); + return tensor_lists[1]; +} + +void foreach_tensor_sigmoid_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_sigmoid_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Sigmoid()); + }); +} + +template +struct Reciprocal { + T one = T(1); + __device__ T operator()(T t) const { return (one / t); } +}; + +std::vector foreach_tensor_reciprocal_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_reciprocal_slow(tensors); + } + + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Reciprocal()); + }); + return tensor_lists[1]; +} + +void foreach_tensor_reciprocal_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_reciprocal_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Reciprocal()); + }); +} + +template +struct Truncf { + __device__ T operator()(T t) const { return std::trunc(t); } +}; + +std::vector foreach_tensor_trunc_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_trunc_slow(tensors); + } + + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Truncf()); + }); + return tensor_lists[1]; +} + +void foreach_tensor_trunc_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_trunc_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Truncf()); + }); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu index bee3cfa4d436..41fc2dea5856 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu @@ -273,8 +273,8 @@ void fractional_max_pool2d_backward_out_cuda_template( <<>>( devGradInput, devGradOutput, devIndices); C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - ); + } + ); } }// namespace diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 7d7a59b32406..cb4aa644fee2 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace at { namespace native { @@ -229,6 +230,10 @@ void take_out_cuda_template(Tensor& output, const Tensor& input, const Tensor& i TORCH_CHECK(!(input.numel() == 0 && index.numel() != 0), "tried to take from an empty tensor"); + at::assert_no_internal_overlap(output); + at::assert_no_partial_overlap(output, index); + at::assert_no_overlap(output, input); + output.resize_(index.sizes()); AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, input.scalar_type(), "take_cuda", [&] { diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 47527935fe73..d630d727019f 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -232,7 +232,6 @@ void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & va AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, value_.scalar_type(), "indexing_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "indexing_backward", [&] { indexing_backward_kernel<<>>( sorted_indices.data_ptr(), orig_indices.data_ptr(), @@ -243,8 +242,7 @@ void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & va strideBefore, nElemBefore); }); - }); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); if (permuted) self.copy_(src_.permute(inversePerm)); } @@ -446,6 +444,10 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const TORCH_CHECK(index.numel() == (source.dim() == 0 ? 1 : source.size(dim)), "index_add_(): Number of indices should be equal to self.size(dim)"); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, source); + // Scalars are treated as 1-d tensor Tensor self_ = (self.dim() == 0) ? self.view(1) : self; Tensor source_ = (source.dim() == 0) ? source.view(1) : source; @@ -476,21 +478,23 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ indexAddSmallIndex \ - <<>>( \ - selfInfo, sourceInfo, indexInfo, \ - selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); + <<>>( \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ - SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \ + SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \ indexAddLargeIndex \ - <<>>( \ - selfInfo, sourceInfo, indexInfo, \ - selfAddDim, sourceAddDim, sourceTotalSize, \ - (IDX_IS_MAJOR) ? sliceSize : numIndex, \ - selfAddDimSize); + SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR> \ + <<>>( \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sourceTotalSize, \ + (IDX_IS_MAJOR) ? sliceSize : numIndex, \ + selfAddDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); @@ -502,77 +506,73 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const cuda::detail::canUse32BitIndexMath(source) && cuda::detail::canUse32BitIndexMath(index)) { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "index_add", [&] { - cuda::detail::TensorInfo selfInfo = - cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); - selfInfo.reduceDim(selfAddDim); - AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { - auto sourceInfo = - cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); - sourceInfo.reduceDim(sourceAddDim); - - auto indexInfo = - cuda::detail::getTensorInfo(index); - indexInfo.collapseDims(); - - // A reasonable choice for when to have each thread iterate over - // index to choose - if (numIndex <= 16) { - if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); - } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { - SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); - } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { - SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + cuda::detail::TensorInfo selfInfo = + cuda::detail::getTensorInfo(self_); + int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { + auto sourceInfo = + cuda::detail::getTensorInfo(source_); + int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + auto indexInfo = + cuda::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + // A reasonable choice for when to have each thread iterate over + // index to choose + if (numIndex <= 16) { + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + } else { + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + } + } else { + bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); + + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); } else { - SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); } - } else { - bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); - - if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); - } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); - } else { - LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); - } - } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); - } else { - LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); - } + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); } else { - LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); } + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); } - }); + } }); }); } else { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "index_add", [&] { - cuda::detail::TensorInfo selfInfo = - cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); - selfInfo.reduceDim(selfAddDim); - - cuda::detail::TensorInfo sourceInfo = - cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); - sourceInfo.reduceDim(sourceAddDim); - - AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { - cuda::detail::TensorInfo indexInfo = - cuda::detail::getTensorInfo(index); - indexInfo.collapseDims(); - - LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); - }); + cuda::detail::TensorInfo selfInfo = + cuda::detail::getTensorInfo(self_); + int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + + cuda::detail::TensorInfo sourceInfo = + cuda::detail::getTensorInfo(source_); + int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { + cuda::detail::TensorInfo indexInfo = + cuda::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); }); }); } @@ -725,22 +725,24 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim, int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ indexSelectSmallIndex \ - <<>>( \ - outInfo, selfInfo, indicesInfo, \ - outSelectDim, selfSelectDim, static_cast(sliceSize), \ - selfSelectDimSize); + <<>>( \ + outInfo, selfInfo, indicesInfo, \ + outSelectDim, selfSelectDim, static_cast(sliceSize), \ + selfSelectDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ - DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ + DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ indexSelectLargeIndex \ - <<>>( \ - outInfo, selfInfo, indicesInfo, \ - outSelectDim, selfSelectDim, static_cast(outTotalSize), \ - static_cast((IDX_IS_MAJOR) ? sliceSize : numIndices), \ - selfSelectDimSize); + DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR> \ + <<>>( \ + outInfo, selfInfo, indicesInfo, \ + outSelectDim, selfSelectDim, static_cast(outTotalSize), \ + static_cast((IDX_IS_MAJOR) ? sliceSize : numIndices), \ + selfSelectDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); @@ -824,22 +826,17 @@ Tensor& index_select_out_cuda(Tensor& out, const Tensor& self, int64_t dim, TORCH_CHECK(at::cuda::check_device({out, self, index}), "Input, output and indices must be on the current device"); at::assert_no_internal_overlap(out); + at::assert_no_overlap(out, self); + at::assert_no_overlap(out, index); dim = at::maybe_wrap_dim(dim, self); TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); -#if defined(__HIP_PLATFORM_HCC__) AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, out.scalar_type(), "index_select_cuda", [&] { index_select_out_cuda_impl(out, self, dim, index); }); -#else // __HIP_PLATFORM_HCC__ - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::Bool, - out.scalar_type(), "index_select_cuda", - [&] { index_select_out_cuda_impl(out, self, dim, index); }); -#endif // __HIP_PLATFORM_HCC__ return out; } diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index 82765b2aeddb..e74debfb29be 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -152,6 +152,11 @@ void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { if (iter.is_cpu_scalar(1)) { AUnaryFunctor af(f, iter.scalar_value(1)); iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly const OptionalDeviceGuard device_guard(device_of(iter.tensor(1))); gpu_kernel(iter, af); } else if (iter.is_cpu_scalar(2)) { diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index a8779d3d97af..3d59617903b4 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -74,6 +74,7 @@ void renormRows(Tensor& t) { <<>>(t.data_ptr(), rows, cols); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -348,6 +349,7 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n self_v.stride(0), self_v.stride(1) ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // Generic, slow implementation with memory allocations @@ -399,12 +401,11 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n numDist, numCategories, prefixSum.data_ptr(), normDist.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } }); - AT_CUDA_CHECK(cudaGetLastError()); - if (inputSize == 1) { result.resize_({n_sample}); } diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 4830ca149cff..186099dfde50 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -5,26 +5,24 @@ namespace at { namespace native { std::tuple batch_norm_cuda_out(Tensor& output, Tensor& save_mean, Tensor& save_invstd, const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_cuda", [&] { - auto mean_st = running_mean.dtype(); - auto var_st = running_var.dtype(); - TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } else { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } + auto mean_st = running_mean.dtype(); + auto var_st = running_var.dtype(); + TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); } else { - if (is_half_float || is_bfloat16_float) { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } else { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); } - }); + } else { + if (is_half_float || is_bfloat16_float) { + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); + } else { + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); + } + } }); return std::tuple(output, save_mean, save_invstd); } @@ -54,38 +52,34 @@ std::tuple batch_norm_cuda(const Tensor& self, const Ten std::tuple batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array grad_input_mask) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward_cuda", [&] { - auto mean_st = running_mean.dtype(); - auto var_st = running_var.dtype(); - TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } else { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } + auto mean_st = running_mean.dtype(); + auto var_st = running_var.dtype(); + TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); + } else { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); + } + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } else { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); } - }); + } }); } std::tuple batch_norm_stats_cuda(const Tensor& self, double epsilon) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_stats_cuda", [&] { - if (cuda::detail::canUse32BitIndexMath(self)) { - return batch_norm_stats_cuda_template(self, epsilon); - } else { - return batch_norm_stats_cuda_template(self, epsilon); - } - }); + if (cuda::detail::canUse32BitIndexMath(self)) { + return batch_norm_stats_cuda_template(self, epsilon); + } else { + return batch_norm_stats_cuda_template(self, epsilon); + } }); } @@ -99,26 +93,24 @@ Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Ten Tensor& batch_norm_elemt_cuda_out(Tensor& output, const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& mean, const Tensor& invstd, double epsilon) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_elemt", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_elemt", [&] { - auto mean_st = mean.dtype(); - auto invstd_st = invstd.dtype(); - TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } else { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); + } else { + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); + } + } else { + if (is_half_float || is_bfloat16_float) { + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); } else { - if (is_half_float || is_bfloat16_float) { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } else { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); } - }); + } }); return output; } @@ -137,95 +129,87 @@ std::tuple batch_norm_gather_stats_with_counts_cuda(const Tensor const Tensor& running_var, double momentum, double epsilon, const Tensor& counts) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, running_mean.scalar_type(), "batch_norm_update_stats_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_update_stats_cuda", [&] { - using accscalar_t = at::acc_type; - if (cuda::detail::canUse32BitIndexMath(self)) { - return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); - } else { - return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); - } - }); + using accscalar_t = at::acc_type; + if (cuda::detail::canUse32BitIndexMath(self)) { + return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); + } else { + return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); + } }); } std::tuple batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, const Tensor& weight, bool input_g, bool weight_g, bool bias_g) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_reduce", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward_reduce", [&] { - auto mean_st = mean.dtype(); - auto invstd_st = invstd.dtype(); - TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } else { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); + } else { + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); + } + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } else { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); } - }); + } }); } Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, const Tensor& weight, const Tensor& mean_dy, const Tensor& mean_dy_xmu) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_elemt", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward_elemt", [&] { - auto mean_st = mean.dtype(); - auto invstd_st = invstd.dtype(); - TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } else { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } else { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); } - }); + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); + } else { + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); + } + } }); } std::tuple batch_norm_update_stats_cuda( const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward", [&] { - auto mean_st = running_mean.dtype(); - auto var_st = running_var.dtype(); - TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); - // Some workloads depend on passing in half input and float stats, which is - // usually handled by cuDNN. However, the JIT sometimes replaces cuDNN calls with this - // one so it needs to support the same case, or people start to complain. - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } else { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } + auto mean_st = running_mean.dtype(); + auto var_st = running_var.dtype(); + TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); + // Some workloads depend on passing in half input and float stats, which is + // usually handled by cuDNN. However, the JIT sometimes replaces cuDNN calls with this + // one so it needs to support the same case, or people start to complain. + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); + } else { + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); + } + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } else { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); } - }); + } }); } diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index a0d37dd44be1..8355ac004308 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -558,6 +558,7 @@ void batch_norm_cuda_template(Tensor& output_, Tensor& save_mean_, Tensor& save_ if (!train) { batch_norm_transform_input_kernel <<>> (input, output, running_mean, running_var, weight, bias, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in // the feature dimension, we'll use some threads for blocks @@ -566,10 +567,11 @@ void batch_norm_cuda_template(Tensor& output_, Tensor& save_mean_, Tensor& save_ dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); batch_norm_collect_statistics_kernel <<>> (input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd); + C10_CUDA_KERNEL_LAUNCH_CHECK(); batch_norm_transform_input_kernel <<>> (input, output, save_mean, save_invstd, weight, bias, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); } template @@ -615,7 +617,7 @@ std::tuple batch_norm_backward_cuda_template(const Tenso batch_norm_backward_kernel <<>> (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(grad_input_, grad_weight_, grad_bias_); } @@ -654,7 +656,7 @@ std::tuple batch_norm_stats_cuda_template(const Tensor& input_, dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); batch_norm_collect_statistics_kernel <<>> (input, epsilon, 0.0, dummy_mean, dummy_invstd, mean, invstd); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(mean_, invstd_); } @@ -694,7 +696,7 @@ void batch_norm_elemt_cuda_template(Tensor& output_, const Tensor& input_, const dim3 threads_trans(tf, tb); batch_norm_transform_input_kernel <<>> (input, output, mean, invstd, weight, bias, epsilon); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -727,7 +729,7 @@ std::tuple batch_norm_gather_stats_cuda_template(const Tensor& m int grid = std::max(1, features/block); batch_norm_reduce_statistics_kernel <<>> (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(save_mean_, save_invstd_); } @@ -777,7 +779,7 @@ std::tuple batch_norm_backward_reduce_cuda_templ batch_norm_backward_reduce_kernel <<>> (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); } @@ -819,7 +821,7 @@ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Te dim3 threads_trans(tf, tb); batch_norm_backward_elemt_kernel <<>> (input, grad_output, mean, invstd, weight, mean_dy, mean_dy_xmu, grad_input); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input_reshaped.view(input_.sizes()); } @@ -853,7 +855,7 @@ std::tuple batch_norm_update_stats_cuda_template( // NB: epsilon is unused by the Var transform, so we set it to 0 batch_norm_collect_statistics_kernel <<>> (input, 0., momentum, running_mean, running_var, save_mean, save_var); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(save_mean_, save_var_); } diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh index 8265c5999376..051583a12a53 100644 --- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh +++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh @@ -258,50 +258,24 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ + softmax_warp_forward \ + <<>>(dst, \ + src, batch_count, softmax_elements_stride, softmax_elements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024 default: break; } @@ -333,53 +307,27 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; + #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ + softmax_warp_backward \ + <<>> \ + (grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024 default: break; } } } - diff --git a/aten/src/ATen/native/cuda/ROCmLoops.cuh b/aten/src/ATen/native/cuda/ROCmLoops.cuh index b5115c6dcdfb..c339364b5a02 100644 --- a/aten/src/ATen/native/cuda/ROCmLoops.cuh +++ b/aten/src/ATen/native/cuda/ROCmLoops.cuh @@ -134,7 +134,7 @@ static void launch_kernel(int64_t N, const func_t& f) { dim3 grid((N + block.x * vt - 1) / (block.x * vt)); auto stream = at::cuda::getCurrentCUDAStream(); elementwise_kernel<<>>(N, f); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -296,7 +296,7 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) { int64_t grid = (N + block_work_size - 1) / block_work_size; auto stream = at::cuda::getCurrentCUDAStream(); elementwise_kernel<<>>(N, f, data); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template::value, int> = 0> diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 4286f05111b6..107c3c28fdac 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -39,8 +39,10 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) { using scalar_t = typename function_traits::result_type; if (N <= std::numeric_limits::max()) { elementwise_kernel_with_index<<>>(N, f, output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { elementwise_kernel_with_index<<>>(N, f, output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } @@ -105,7 +107,6 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, c10::optiona result.copy_(r); } - AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -164,7 +165,6 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, c10::optiona result.copy_(r); } - AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -201,7 +201,6 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { }); - AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -263,7 +262,6 @@ Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { } }); - AT_CUDA_CHECK(cudaGetLastError()); return result; } diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 618088cefb3a..ea797e6011af 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -817,15 +817,16 @@ static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) switch(config.output_vec_size) { case 4: reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; case 2: reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; default: reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - - AT_CUDA_CHECK(cudaGetLastError()); } class AccumulationBuffer { @@ -872,7 +873,7 @@ int get_output_vec_size(TensorIterator &iter) { vec_size /= 2; } }; - + uint64_t base_address = reinterpret_cast(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t); update_vec_size(base_address); diff --git a/aten/src/ATen/native/cuda/ReduceNormKernel.cu b/aten/src/ATen/native/cuda/ReduceNormKernel.cu index 39a355a96756..3953f16b69c9 100644 --- a/aten/src/ATen/native/cuda/ReduceNormKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceNormKernel.cu @@ -7,48 +7,49 @@ namespace at { namespace native { -template +// This reduction accumulates results as the type `acc_t`. By default, when +// `scalar_t` is complex, `acc_t` is the downgraded real number type. +// Otherwise, `acc_t` and `scalar_t` are the same type. +template ::type, typename out_t=typename scalar_value_type::type> void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) { - float p; + double p; if (val.isIntegral(false)) { p = val.to(); } else if (val.isFloatingPoint()) { - p = val.to(); + p = val.to(); } else { AT_ERROR("norm_kernel_cuda_impl expects norm to be integer or float"); } - if (p == static_cast(0)) { - gpu_reduce_kernel(iter, NormZeroOps(), 0); - } else if (p == static_cast(1)) { - gpu_reduce_kernel(iter, NormOneOps(), 0); - } else if (p == static_cast(2)) { - gpu_reduce_kernel(iter, NormTwoOps(), 0); - } else if (p == static_cast(INFINITY)) { - gpu_reduce_kernel(iter, AbsMaxOps(), std::numeric_limits::min()); - } else if (p == static_cast(-INFINITY)) { - gpu_reduce_kernel(iter, AbsMinOps(), std::numeric_limits::max()); + if (p == static_cast(0)) { + gpu_reduce_kernel(iter, NormZeroOps(), 0); + } else if (p == static_cast(1)) { + gpu_reduce_kernel(iter, NormOneOps(), 0); + } else if (p == static_cast(2)) { + gpu_reduce_kernel(iter, NormTwoOps(), 0); + } else if (p == static_cast(INFINITY)) { + gpu_reduce_kernel(iter, AbsMaxOps(), std::numeric_limits::min()); + } else if (p == static_cast(-INFINITY)) { + gpu_reduce_kernel(iter, AbsMinOps(), std::numeric_limits::max()); } else { - gpu_reduce_kernel(iter, NormOps{ acc_t(p) }, 0); + gpu_reduce_kernel(iter, NormOps{ acc_t(p) }, 0); } } static void norm_kernel_cuda(TensorIterator& iter, Scalar p) { - if (iter.dtype() == kHalf) { + if (iter.input_dtype() == kHalf) { return norm_kernel_cuda_impl(iter, p); - } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) { + } else if (iter.dtype(1) == kHalf && iter.input_dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel return norm_kernel_cuda_impl(iter, p); } - #ifdef __HIP_PLATFORM_HCC__ - else if(iter.dtype() == kBFloat16) { + else if(iter.input_dtype() == kBFloat16) { return norm_kernel_cuda_impl(iter, p); - } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) { + } else if (iter.dtype(1) == kBFloat16 && iter.input_dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel return norm_kernel_cuda_impl(iter, p); } - #endif - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] { norm_kernel_cuda_impl(iter, p); }); } diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu index 2b182f32b5e7..95a6825d507f 100644 --- a/aten/src/ATen/native/cuda/ReflectionPad.cu +++ b/aten/src/ATen/native/cuda/ReflectionPad.cu @@ -200,10 +200,9 @@ void reflection_pad1d_out_template( grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( input.data_ptr(), output.data_ptr(), input_w, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } void reflection_pad1d_backward_out_template( @@ -213,7 +212,7 @@ void reflection_pad1d_backward_out_template( if (grad_input.numel() == 0) { return; } - + TORCH_CHECK(canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); @@ -252,15 +251,14 @@ void reflection_pad1d_backward_out_template( grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( grad_input.data_ptr(), grad_output.data_ptr(), input_w, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } void reflection_pad2d_out_template( Tensor &output, const Tensor &input_, IntArrayRef padding) { - + TORCH_CHECK(canUse32BitIndexMath(input_), "input tensor must fit into 32-bit index math"); @@ -331,10 +329,9 @@ void reflection_pad2d_out_template( input.data_ptr(), output.data_ptr(), input_w, input_h, pad_t, pad_b, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } void reflection_pad2d_backward_out_template( @@ -344,7 +341,7 @@ void reflection_pad2d_backward_out_template( if (grad_input.numel() == 0) { return; } - + TORCH_CHECK(canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); TORCH_CHECK(canUse32BitIndexMath(grad_output_), @@ -393,10 +390,9 @@ void reflection_pad2d_backward_out_template( grad_input.data_ptr(), grad_output.data_ptr(), input_w, input_h, pad_t, pad_b, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu index f70459928bf0..8437e80ebb48 100644 --- a/aten/src/ATen/native/cuda/Repeat.cu +++ b/aten/src/ATen/native/cuda/Repeat.cu @@ -23,6 +23,7 @@ static void compute_cuda(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *resu int64_t grid = std::min((size + warps_per_block - 1) / warps_per_block, 2048L); compute_cuda_kernel<<>>(repeat_ptr, cumsum_ptr, result_ptr, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } namespace at { namespace native { diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu index b896a47afed9..8f164c8476f7 100644 --- a/aten/src/ATen/native/cuda/ReplicationPadding.cu +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -222,7 +222,7 @@ void replication_pad1d_out_cuda_template( (numInputDims == 3 && input.size(1) != 0 && input.size(2) != 0), "Expected 2D or 3D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", input.sizes()); - + if (numInputDims == 3) { numBatch = input.size(0); planeDim++; @@ -238,17 +238,17 @@ void replication_pad1d_out_cuda_template( " Calculated output W: ", outputW); if (numInputDims == 2) { - output.resize_({numPlanes, outputW}); + output.resize_({numPlanes, outputW}); } else { output.resize_({numBatch, numPlanes, outputW}); } - + if (input.numel() == 0) { return; } AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad1d_cuda", [&] { + input.scalar_type(), "replication_pad1d_cuda", [&] { if (numInputDims == 2) { auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); @@ -263,6 +263,7 @@ void replication_pad1d_out_cuda_template( replication_pad_forward_kernel1d <<>>(devInput, devOutput, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto devInput = input.packed_accessor64(); auto devOutput = output.packed_accessor64(); @@ -275,10 +276,10 @@ void replication_pad1d_out_cuda_template( replication_pad_forward_kernel1d <<>>(devInput, devOutput, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } + } ); - AT_CUDA_CHECK(cudaGetLastError()); } void replication_pad1d_backward_out_cuda_template( @@ -323,8 +324,8 @@ void replication_pad1d_backward_out_cuda_template( auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; if (numInputDims == 2) { - gradInput_ = gradInput.unsqueeze(0); - gradOutput_ = gradOutput.unsqueeze(0); + gradInput_ = gradInput.unsqueeze(0); + gradOutput_ = gradOutput.unsqueeze(0); } auto devGradInput = gradInput_.packed_accessor64(); auto devGradOutput = gradOutput_.packed_accessor64(); @@ -338,9 +339,8 @@ void replication_pad1d_backward_out_cuda_template( replication_pad_backward_kernel <<>>(devGradInput, devGradOutput, padL, padR); - } - ); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } void replication_pad2d_out_cuda_template( @@ -387,19 +387,17 @@ void replication_pad2d_out_cuda_template( " Calculated output H: ", outputH, " W: ", outputW); if (numInputDims == 3) { - output.resize_({numPlanes, outputH, outputW}); + output.resize_({numPlanes, outputH, outputW}); } else { output.resize_({numBatch, numPlanes, outputH, outputW}); } - + if (input.numel() == 0) { return; } AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad2d_cuda", [&] { - - + input.scalar_type(), "replication_pad2d_cuda", [&] { if (numInputDims == 3) { auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); @@ -415,6 +413,7 @@ void replication_pad2d_out_cuda_template( replication_pad_forward_kernel2d <<>>( devInput, devOutput, padT, padB, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto devInput = input.packed_accessor64(); auto devOutput = output.packed_accessor64(); @@ -428,10 +427,10 @@ void replication_pad2d_out_cuda_template( replication_pad_forward_kernel2d <<>>(devInput, devOutput, padT, padB, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } + } ); - AT_CUDA_CHECK(cudaGetLastError()); } void replication_pad2d_backward_out_cuda_template( @@ -499,9 +498,9 @@ void replication_pad2d_backward_out_cuda_template( replication_pad_backward_kernel <<>>(devGradInput, devGradOutput, padT, padB, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - AT_CUDA_CHECK(cudaGetLastError()); } static inline void shapeCheck3d( @@ -650,10 +649,9 @@ void replication_pad3d_out_cuda_template( if (input.numel() == 0) { return; } - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad3d_cuda", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "replication_pad3d_cuda", [&] { if (numInputDims == 4) { auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); @@ -670,6 +668,7 @@ void replication_pad3d_out_cuda_template( replication_pad_forward_kernel3d <<>>( devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto devInput = input.packed_accessor64(); auto devOutput = output.packed_accessor64(); @@ -684,10 +683,10 @@ void replication_pad3d_out_cuda_template( replication_pad_forward_kernel3d <<>>( devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } + } ); - AT_CUDA_CHECK(cudaGetLastError()); } void replication_pad3d_backward_out_cuda_template( @@ -726,8 +725,7 @@ void replication_pad3d_backward_out_cuda_template( gradInput.zero_(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad3d_backward_cuda", [&] { - + input.scalar_type(), "replication_pad3d_backward_cuda", [&] { auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; if (numInputDims == 4) { @@ -747,9 +745,9 @@ void replication_pad3d_backward_out_cuda_template( replication_pad_backward_kernel <<>>( devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright); - } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } ); - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace @@ -795,7 +793,7 @@ Tensor replication_pad1d_backward_cuda( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad1d_backward_cuda"); - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); replication_pad1d_backward_out_cuda_template( gradInput, gradOutput, input, paddingSize); return gradInput; @@ -843,7 +841,7 @@ Tensor replication_pad2d_backward_cuda( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad2d_backward_cuda"); - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); replication_pad2d_backward_out_cuda_template( gradInput, gradOutput, input, paddingSize); return gradInput; @@ -891,7 +889,7 @@ Tensor replication_pad3d_backward_cuda( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad3d_backward_cuda"); - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); replication_pad3d_backward_out_cuda_template( gradInput, gradOutput, input, paddingSize); return gradInput; diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu index 099512912203..384854505054 100644 --- a/aten/src/ATen/native/cuda/ScanKernels.cu +++ b/aten/src/ATen/native/cuda/ScanKernels.cu @@ -183,7 +183,7 @@ __host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Te tensor_kernel_scan_outer_dim_with_indices<<>>( self.data_ptr(), values.data_ptr(), indices.data_ptr(), num_orows, num_irows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -199,7 +199,7 @@ __host__ void scan_innermost_dim_with_indices(const Tensor& self, Tensor& values tensor_kernel_scan_innermost_dim_with_indices<<>>( self.data_ptr(), values.data_ptr(), indices.data_ptr(), num_rows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -436,7 +436,7 @@ __host__ void scan_outer_dim(const Tensor& self, Tensor& result, tensor_kernel_scan_outer_dim<<>>( result.data_ptr(), self.data_ptr(), num_orows, num_irows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -456,7 +456,7 @@ void scan_innermost_dim(const Tensor& self, Tensor& result, scalar_t init, Binar tensor_kernel_scan_innermost_dim<<>>( result.data_ptr(), self.data_ptr(), num_rows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -485,6 +485,7 @@ void scan_cub(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction result.data_ptr() + i - 1, self.data_ptr() + i, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } size_t temp_storage_bytes = 0; AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan( diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu index 552384b45945..ff3b5bb08baa 100644 --- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu +++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu @@ -72,11 +72,11 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) { return; } - dim3 block(nt); - dim3 grid((N + block.x * vt - 1) / (block.x * vt)); - auto stream = at::cuda::getCurrentCUDAStream(); + const dim3 block(nt); + const dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + const auto stream = at::cuda::getCurrentCUDAStream(); _scatter_gather_elementwise_kernel<<>>(N, f); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -192,7 +192,7 @@ struct cuda_scatter_gather_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_gather_base_kernel_func", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -264,7 +264,7 @@ struct cuda_scatter_gather_base_kernel { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_gather_base_kernel_reduce_multiply", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -365,7 +365,7 @@ struct cuda_scatter_fill_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_fill_base_kernel_func", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -417,7 +417,7 @@ struct cuda_scatter_fill_base_kernel { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_fill_base_kernel_reduce_multiply", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -494,5 +494,5 @@ REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cuda_kernel); REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cuda_kernel); REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cuda_kernel); REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cuda_kernel); - + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 7632438ba523..2831292845ec 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -294,7 +294,8 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, #define HANDLE_CASE(DIMS) \ HIP_CatArrayBatchedCopy<<<\ catGrid, applyBlock, 0, stream.stream()>>>(\ - data, d_inputs, outputParam, dimension, outputParam.tensorStride[dimension]); + data, d_inputs, outputParam, dimension, outputParam.tensorStride[dimension]); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); switch (nDims) { case 1: HANDLE_CASE(1); @@ -310,7 +311,6 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, break; } #undef HANDLE_CASE - AT_CUDA_CHECK(cudaGetLastError()); } } @@ -404,7 +404,8 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, #define HANDLE_CASE(DIMS) \ CatArrayBatchedCopy<<<\ catGrid, applyBlock, 0, stream.stream()>>>(\ - data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]); + data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); switch (nDims) { case 1: HANDLE_CASE(1); @@ -420,7 +421,6 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, break; } #undef HANDLE_CASE - AT_CUDA_CHECK(cudaGetLastError()); } } } // namespace diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index ca00a3520f29..fb43dcb4c3c3 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -709,32 +709,32 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t if (inner_size == 1) { dim3 grid(outer_size); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { - using accscalar_t = acc_type; - if (!half_to_float) { - if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { - dispatch_softmax_forward( - output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); - } else { - constexpr int ILP = sizeof(float4) / sizeof(scalar_t); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), dim_size - ); - } - } else { - if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { - dispatch_softmax_forward( - output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); + using accscalar_t = acc_type; + if (!half_to_float) { + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + dispatch_softmax_forward( + output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); + } else { + constexpr int ILP = sizeof(float4) / sizeof(scalar_t); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), dim_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } else { - constexpr int ILP = sizeof(float4) / sizeof(accscalar_t); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), dim_size - ); + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + dispatch_softmax_forward( + output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); + } else { + constexpr int ILP = sizeof(float4) / sizeof(accscalar_t); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), dim_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } - } }); // This kernel runs in a 2D grid, where each application along y dimension has a fixed // outer_size, and runs in parallel over inner_size. Dimension x is parallel over outer_size. @@ -743,29 +743,28 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t uint32_t smem_size; dim3 grid, block; AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { - using accscalar_t = acc_type; - if (!half_to_float) { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxForward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - cunn_SpatialSoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size - ); - } else { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxForward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - cunn_SpatialSoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size - ); - } + using accscalar_t = acc_type; + if (!half_to_float) { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxForward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + cunn_SpatialSoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxForward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + cunn_SpatialSoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } }); } - AT_CUDA_CHECK(cudaGetLastError()); } return output; } @@ -807,6 +806,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t <<>>( gI.data_ptr(), output.data_ptr(), grad.data_ptr(), dim_size ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } else { if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { @@ -819,6 +819,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t <<>>( gI.data_ptr(), output.data_ptr(), grad.data_ptr(), dim_size ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } }); @@ -826,33 +827,35 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t uint32_t smem_size; dim3 grid, block; AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] { - using accscalar_t = acc_type; - if (!half_to_float) { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxBackward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - - cunn_SpatialSoftMaxBackward - <<>>( - gI.data_ptr(), output.data_ptr(), grad.data_ptr(), - outer_size, dim_size, inner_size - ); - } else { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxBackward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - - cunn_SpatialSoftMaxBackward - <<>>( - gI.data_ptr(), output.data_ptr(), grad.data_ptr(), - outer_size, dim_size, inner_size - ); - } + using accscalar_t = acc_type; + if (!half_to_float) { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxBackward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + + cunn_SpatialSoftMaxBackward + <<>>( + gI.data_ptr(), output.data_ptr(), grad.data_ptr(), + outer_size, dim_size, inner_size + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxBackward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + + cunn_SpatialSoftMaxBackward + <<>>( + gI.data_ptr(), output.data_ptr(), grad.data_ptr(), + outer_size, dim_size, inner_size + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } }); } - AT_CUDA_CHECK(cudaGetLastError()); + return gI; } } diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index 59b07653593e..33fc4a18bffa 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -204,6 +204,7 @@ struct KthValueLauncher { self_info.strides[collapse_self_dim], values_info, indices_info); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }; @@ -238,6 +239,7 @@ struct MedianLauncher { num_slices, self_info.strides[collapse_self_dim], ignore_nan); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }; @@ -290,8 +292,6 @@ void kthvalue_cuda_template( values.squeeze_(dim); indices.squeeze_(dim); } - - AT_CUDA_CHECK(cudaGetLastError()); } std::tuple kthvalue_out_impl_cuda( @@ -371,8 +371,6 @@ std::tuple median_with_indices_impl( vals, inds, in, dim, MedianLauncher(ignore_nan)); } }); - - AT_CUDA_CHECK(cudaGetLastError()); } guard.reset(); diff --git a/aten/src/ATen/native/cuda/SpectralOps.cu b/aten/src/ATen/native/cuda/SpectralOps.cu index 3ad0c06c69fc..db3e853a9321 100644 --- a/aten/src/ATen/native/cuda/SpectralOps.cu +++ b/aten/src/ATen/native/cuda/SpectralOps.cu @@ -125,6 +125,7 @@ void _fft_fill_with_conjugate_symmetry_cuda_( static_cast(in_data), input_offset_calculator, output_offset_calculator); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index 443bea3f71ac..b10ae52e44fd 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -17,7 +17,7 @@ DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub); namespace { void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] { if (condition_type == at::ScalarType::Byte) { gpu_kernel( iter, diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index a241f7df533c..effeef69f0cf 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -357,6 +357,7 @@ Tensor tril_indices_cuda( col, tril_size - rectangle_size, tril_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -434,6 +435,7 @@ Tensor triu_indices_cuda( col, rectangle_size, triu_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } diff --git a/aten/src/ATen/native/cuda/TriangularOps.cu b/aten/src/ATen/native/cuda/TriangularOps.cu index 6ba73e1c143e..8d497b5c94af 100644 --- a/aten/src/ATen/native/cuda/TriangularOps.cu +++ b/aten/src/ATen/native/cuda/TriangularOps.cu @@ -60,7 +60,7 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c int64_t N = self.numel(); dim3 dim_block = cuda::getApplyBlock(); dim3 dim_grid((N + dim_block.x - 1) / dim_block.x); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), name, [&]{ + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu_tril_cuda_template", [&]{ if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) { auto result_info = cuda::detail::getTensorInfo(result); auto self_info = cuda::detail::getTensorInfo(self); diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu index 2488528f5e2c..867855217092 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu @@ -59,7 +59,7 @@ void sinh_kernel_cuda(TensorIterator& iter) { } void cosh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "cosh_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "cosh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cosh(a); }); diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 512154fd02df..4d676181be79 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -41,7 +41,7 @@ void exp_kernel_cuda(TensorIterator& iter) { } void exp2_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exp2_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp2_cuda", [&]() { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::exp2(a); }); @@ -49,7 +49,7 @@ void exp2_kernel_cuda(TensorIterator& iter) { } void expm1_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "expm1_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "expm1_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::expm1(a); }); diff --git a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu index ef287ca592da..99488108ac26 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu @@ -196,22 +196,15 @@ static void upsample_nearest1d_backward_out_cuda_template( } // namespace -Tensor& upsample_nearest1d_out_cuda( +TORCH_IMPL_FUNC(upsample_nearest1d_out_cuda) ( Tensor& output, const Tensor& input, IntArrayRef output_size, c10::optional scales) { upsample_nearest1d_out_cuda_template(output, input, output_size, scales); - return output; -} - -Tensor upsample_nearest1d_cuda(const Tensor& input, IntArrayRef output_size, c10::optional scales) { - Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - upsample_nearest1d_out_cuda_template(output, input, output_size, scales); - return output; } -Tensor& upsample_nearest1d_backward_out_cuda( +TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_cuda) ( Tensor& grad_input, const Tensor& grad_output, IntArrayRef output_size, @@ -219,18 +212,6 @@ Tensor& upsample_nearest1d_backward_out_cuda( c10::optional scales) { upsample_nearest1d_backward_out_cuda_template( grad_input, grad_output, output_size, input_size, scales); - return grad_input; -} - -Tensor upsample_nearest1d_backward_cuda( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales) { - Tensor grad_input = at::empty_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - upsample_nearest1d_backward_out_cuda_template( - grad_input, grad_output, output_size, input_size, scales); - return grad_input; } using at::native::upsample::compute_output_size; diff --git a/aten/src/ATen/native/cuda/WeightNorm.cu b/aten/src/ATen/native/cuda/WeightNorm.cu index d90dc03007fd..8261eda01a3c 100644 --- a/aten/src/ATen/native/cuda/WeightNorm.cu +++ b/aten/src/ATen/native/cuda/WeightNorm.cu @@ -394,14 +394,14 @@ std::tuple weight_norm_cuda g.data_ptr(), fast_dim_size, slower_dims_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } // The kernel execution is asynchronous, so this will only catch errors on the kernel launch, // not the kernel's execution. Errors in kernel execution aren't guaranteed to be caught // until a later error check on a synchronizing CUDA call. Unfortunately, without manually - // synchronizing here, this is the best we can do. - AT_CUDA_CHECK(cudaGetLastError()); + // synchronizing here, the foregoing is the best we can do. return std::tuple{w, norms}; } @@ -486,14 +486,14 @@ std::tuple weight_norm_cuda_backward saved_norms.data_ptr(), fast_dim_size, slower_dims_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } // The kernel execution is asynchronous, so this will only catch errors on the kernel launch, // not the kernel's execution. Errors in kernel execution aren't guaranteed to be caught // until a later error check on a synchronizing CUDA call. Unfortunately, without manually - // synchronizing here, this is the best we can do. - AT_CUDA_CHECK(cudaGetLastError()); + // synchronizing here, the foregoing is the best we can do. return std::tuple{grad_v, grad_g}; } diff --git a/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp b/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp new file mode 100644 index 000000000000..bac8df92a5fc --- /dev/null +++ b/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp @@ -0,0 +1,147 @@ +#include // for the definition of AT_CUDNN_ENABLED +#include +#include + +namespace at { namespace native { + +// --------------------------------------------------------------------- +// +// Placeholder operators +// +// --------------------------------------------------------------------- + +#if !AT_CUDNN_ENABLED() + +// See Note [ATen preprocessor philosophy] + +at::Tensor cudnn_convolution( + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_backward_input( + IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_backward_weight( + IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support"); +} + +std::tuple cudnn_convolution_backward( + const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_transpose( + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_transpose_backward_input( + const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_transpose_backward_weight( + IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support"); +} + +std::tuple cudnn_convolution_transpose_backward( + const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); +} + +void raw_cudnn_convolution_forward_out( + const Tensor& output, const Tensor& input, const Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("raw_cudnn_convolution_forward_out: ATen not compiled with cuDNN support"); +} + +void raw_cudnn_convolution_backward_input_out( + const at::Tensor& grad_input, + const at::Tensor& grad_output, + const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("raw_cudnn_convolution_backward_input_out: ATen not compiled with cuDNN support"); +} + +void raw_cudnn_convolution_backward_weight_out( + const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("raw_cudnn_convolution_backward_weight_out: ATen not compiled with cuDNN support"); +} + +#endif // AT_CUDNN_ENABLED + +// --------------------------------------------------------------------- +// +// Deprecated operators +// +// --------------------------------------------------------------------- + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_deprecated( + const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) { + auto output = at::cudnn_convolution(input, weight, padding, stride, dilation, groups, benchmark, deterministic); + if (bias.defined()) { + output = output + reshape_bias(input.dim(), bias); + } + return output; +} + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_deprecated2( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + return at::cudnn_convolution(input_t, weight_t, padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); +} + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_transpose_deprecated( + const Tensor& input, const Tensor& weight, const Tensor& bias /* optional */, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + auto output = at::cudnn_convolution_transpose(input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + if (bias.defined()) { + output = output + reshape_bias(input.dim(), bias); + } + return output; +} + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_transpose_deprecated2( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + return at::cudnn_convolution_transpose(input_t, weight_t, padding, output_padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); +} + +}} diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp new file mode 100644 index 000000000000..e360008e2707 --- /dev/null +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -0,0 +1,500 @@ +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() + +#include + +// NOTE [cuDNN API version] +// +// ConvPlaceholders.cpp contains placeholder implementation of cudnn +// convolution when cudnn is not enabled. These operators only raises +// errors, and do no real computation. This file also contains deprecated +// operators. These operators are implemented using currnet operators. +// +// cuDNN v7 and v8 have different API. ConvShared.{cpp, h} contains +// code shared by v7 and v8. Conv_v7.cpp contains implementation of +// convolution using cuDNN v7 API. Conv_v8.cpp contains implementation +// with v8 API. +// +// NOTE [ Convolution design ] +// +// cuDNN convolutions does not handle bias. Bias is handled outside. +// +// The general strategy: +// +// - cudnn_convolution (Tensor) +// Entry points for clients +// +// - cudnn_convolution_forward (TensorArg) +// Entry point, which may be reused between regular +// convolution and transposed convolution. +// +// - raw_cudnn_convolution_forward_out (Tensor) +// Function that has different implementation on Conv_v7.cpp +// and Conv_v8.cpp +// +// The raw API directly invokes CuDNN and are implemeted differently +// on cuDNN v7 and cuDNN v8 +// +// There are a few reasons this should never be directly exposed +// via ATen: +// +// - It takes output as a parameter (this should be computed!) +// - It doesn't do input checking +// - It doesn't resize output (it is assumed to be correctly sized) +// +// Where does argument checking happen? Here's the division of +// responsibility: +// - Things that happen in at::Tensor +// - TensorArg allocation +// - Things that happen in TensorArg +// - Check arguments (type, GPU, shape) + +namespace at { namespace native { + +// --------------------------------------------------------------------- +// +// ConvolutionParams and ConvolutionArgs +// +// --------------------------------------------------------------------- + +std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params) { + out << "ConvolutionParams \n" + << " data_type = " << cudnnTypeToString(params.dataType) << "\n" + << " padding = " << ArrayRef{params.padding} << "\n" + << " stride = " << ArrayRef{params.stride} << "\n" + << " dilation = " << ArrayRef{params.dilation} << "\n" + << " groups = " << params.groups << "\n" + << " deterministic = " << (params.deterministic ? "true" : "false") << "\n" + << " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n"; + + return out; +} + +// NB: This can't be a constructor, because then ConvolutionParams +// would not be a POD anymore. +// TODO: Use TensorGeometry here instead of the entire Tensor, which we +// don't actually need. (OTOH: We can always pass in +// grad_input/grad_output, so this is not very pressing) +void setConvolutionParams( + ConvolutionParams* params, + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool deterministic, bool allow_tf32) { + + cudnnDataType_t dataType = getCudnnDataType(input); + memset(params, 0, sizeof(ConvolutionParams)); + params->dataType = dataType; + // ASSERT(weight.dim() == input.dim()) + for (int i = 0; i != input.dim(); ++i) { + params->input_size[i] = (int) input.size(i); + params->input_stride[i] = (int) input.stride(i); + params->weight_size[i] = (int) weight.size(i); + } + // ASSERT(padding.size() == stride.size()) + // ASSERT(padding.size() == dilation.size()) + for (size_t i = 0; i != padding.size(); ++i) { + params->padding[i] = padding[i]; + params->stride[i] = stride[i]; + params->dilation[i] = dilation[i]; + } + // In principle, we shouldn't parametrize by groups for legacy + // CuDNN, but it doesn't seem worth the effort to actually do this. + params->groups = groups; + params->deterministic = deterministic; + params->allow_tf32 = allow_tf32; +} + +std::string repro_from_args(const ConvolutionArgs& args) { + auto pybool = [](bool b) -> const char* { return b ? "True" : "False"; }; + std::string partial_dtype; + switch (args.params.dataType) { + case CUDNN_DATA_FLOAT: partial_dtype = "float"; break; + case CUDNN_DATA_DOUBLE: partial_dtype = "double"; break; + case CUDNN_DATA_HALF: partial_dtype = "half"; break; + default: partial_dtype = "unsupported"; + } + const std::string full_dtype = "torch." + partial_dtype; + const int out_channels = args.weight.sizes()[0]; + const int in_channels = args.weight.sizes()[1] * args.params.groups; + const size_t dim = args.input.sizes().size(); + const std::string channels_last_xd = dim == 4 ? "channels_last" : "channels_last_3d"; + const std::string to_channels_last = args.input.suggest_memory_format() == at::MemoryFormat::ChannelsLast \ + ? ".to(memory_format=torch." + channels_last_xd + ")" : ""; + + std::ostringstream ss; + ss << "You can try to repro this exception using the following code snippet. "; + ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n"; + ss << "import torch\n"; + ss << "torch.backends.cuda.matmul.allow_tf32 = " << pybool(at::globalContext().allowTF32CuBLAS()) << "\n"; + ss << "torch.backends.cudnn.benchmark = " << pybool(at::globalContext().benchmarkCuDNN()) << "\n"; + ss << "torch.backends.cudnn.deterministic = " << pybool(args.params.deterministic) << "\n"; + ss << "torch.backends.cudnn.allow_tf32 = " << pybool(args.params.allow_tf32) << "\n"; + ss << "data = torch.randn(" << args.input.sizes() << ", dtype=" << full_dtype << ", "; + ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n"; + ss << "net = torch.nn.Conv" << dim-2 << "d(" << in_channels << ", " << out_channels << ", "; + ss << "kernel_size=" << args.weight.sizes().slice(2) << ", "; + ss << "padding=" << ArrayRef(args.params.padding, dim-2) << ", "; + ss << "stride=" << ArrayRef(args.params.stride, dim-2) << ", "; + ss << "dilation=" << ArrayRef(args.params.dilation, dim-2) << ", "; + ss << "groups=" << args.params.groups << ")\n"; + ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last << "\n"; + ss << "out = net(data)\n"; + ss << "out.backward(torch.randn_like(out))\n"; + ss << "torch.cuda.synchronize()\n\n"; + + return ss.str(); +} + +std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args) { + out << repro_from_args(args) // already has a trailing newline + << args.params // already has a trailing newline + << "input: " << args.idesc // already has a trailing newline + << "output: " << args.odesc // already has a trailing newline + << "weight: " << args.wdesc // already has a trailing newline + << "Pointer addresses: " << "\n" + << " input: " << args.input.data_ptr() << "\n" + << " output: " << args.output.data_ptr() << "\n" + << " weight: " << args.weight.data_ptr() << "\n"; + + return out; +} + +// --------------------------------------------------------------------- +// +// Checking +// +// --------------------------------------------------------------------- + +// Used on pad, stride and dilation +static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name) +{ + TORCH_CHECK(args.size() <= expected_size, + "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ", + expected_size, " (while checking arguments for ", c, ")"); + TORCH_CHECK(args.size() >= expected_size, + "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ", + expected_size, " (while checking arguments for ", c, ")"); + + auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;}); + if (num_negative_values > 0){ + std::stringstream ss; + ss << arg_name << " should be greater than zero but got ("; + std::copy(args.begin(), args.end() - 1, std::ostream_iterator(ss,", ")); + ss << args.back() << ")" << " (while checking arguments for " << c << ")"; + AT_ERROR(ss.str()); + } +} + + +// NOTE [ Convolution checks ] +// +// NB: For many call sites, it is not strictly necessary to check all of +// these relationships (for example, for forward convolution, we compute +// the size of output ourselves, so we don't actually need to check +// output. However, writing a single function that does everything +// means we get to reuse it for both forwards and all backwards +// variants, even when the set of "real" inputs varies. The magic of +// relational computing! +// +// (There is one downside, which is that it is slightly harder to write +// error messages which are able to distinguish between real inputs +// (which the user can change) and computed inputs (which the user can +// only indirectly affect). It would be an interesting exercise to +// come up with a general framework to handle such situations.) +static void convolution_shape_check( + CheckedFrom c, + const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) +{ + check_args(c, padding, input->dim() - 2, "padding"); + check_args(c, stride, padding.size(), "stride"); + check_args(c, dilation, padding.size(), "dilation"); + + // Input + checkDimRange(c, input, 3, 6 /* exclusive */); + checkSize(c, input, input_channels_dim, weight->size(1) * groups); + + // Weight + checkSameDim(c, input, weight); + + // TODO: check that output->size() matches output_sizes + // TODO: check that weight matches output->sizes() + checkSameDim(c, input, output); +} + +// --------------------------------------------------------------------- +// +// Convolution forward / Transposed convolution backward +// +// --------------------------------------------------------------------- + +Tensor cudnn_convolution_forward( + CheckedFrom c, + const TensorArg& input, const TensorArg& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + checkAllSameType(c, {input, weight}); + checkAllSameGPU(c, {input, weight}); + + auto layout = cudnn_conv_use_channels_last(*input, *weight) ? + at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; + auto output_t = at::empty( + conv_output_size(input->sizes(), weight->sizes(), + padding, stride, dilation), + input->options(), + layout); + + if (output_t.numel() == 0) { + return output_t; + } + + // Avoid ambiguity of "output" when this is being used as backwards + TensorArg output{ output_t, "result", 0 }; + convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); + + // See #4500 + Tensor weight_contig = weight->contiguous(layout); + // Make sure that NC11 strides follow formula + weight_contig.resize_(weight_contig.sizes(), layout); + Tensor input_contig = input->contiguous(layout); + input_contig.resize_(input_contig.sizes(), layout); + + raw_cudnn_convolution_forward_out( + *output, input_contig, weight_contig, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + + return *output; +} + +Tensor cudnn_convolution( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg input { input_t, "input", 1 }, + weight { weight_t, "weight", 2 }; + CheckedFrom c = "cudnn_convolution"; + auto output_t = cudnn_convolution_forward( + c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return output_t; +} + +// NB: output_padding not needed here, as there is no ambiguity to +// resolve +Tensor cudnn_convolution_transpose_backward_input( + const Tensor& grad_output_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg grad_output { grad_output_t, "grad_output", 1 }, + weight { weight_t, "weight", 2 }; + return cudnn_convolution_forward( + "cudnn_convolution_transpose_backward_input", + grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +std::tuple cudnn_convolution_transpose_backward( + const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + + Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); + + Tensor grad_input, grad_weight; + if (output_mask[0]) { + grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + if (output_mask[1]) { + grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + + return std::tuple{grad_input, grad_weight}; +} + +// --------------------------------------------------------------------- +// +// Convolution backward / Transposed convolution forward +// +// --------------------------------------------------------------------- + +// NOTE [ Backward vs transpose convolutions ] +// +// Backward and transpose are algorithmically equivalent, but they +// compute their geometry differently. In a backwards, you knew what +// the original size of the input tensor was, so you can cache that +// geometry and fill it directly. In transposed convolution, it is +// more conventional to not explicitly specify the output (previously +// input) size, and compute it. This, however, leaves a degree of +// freedom; this degree of freedom is resolved using the +// output_padding parameter. Both of these interfaces are equivalent, +// but they are differently convenient depending on the use case. + +Tensor cudnn_convolution_backward_input( + CheckedFrom c, + IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + checkAllSameType(c, {grad_output, weight}); + checkAllSameGPU(c, {grad_output, weight}); + + auto layout = cudnn_conv_use_channels_last(*grad_output, *weight) ? + at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; + auto grad_input_t = at::empty(input_size, grad_output->options(), layout); + + // Avoid "grad_input" when this is being used as transposed convolution + TensorArg grad_input{ grad_input_t, "result", 0 }; + convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); + + // See #4500 + Tensor weight_contig = weight->contiguous(layout); + // Make sure that NC11 strides follow formula + weight_contig.resize_(weight_contig.sizes(), layout); + + Tensor grad_output_contig = grad_output->contiguous(layout); + grad_output_contig.resize_(grad_output_contig.sizes(), layout); + + raw_cudnn_convolution_backward_input_out( + *grad_input, grad_output_contig, weight_contig, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + + return *grad_input; +} + +Tensor cudnn_convolution_transpose_forward( + CheckedFrom c, + const TensorArg& grad_output, const TensorArg& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(), + padding, output_padding, stride, dilation, groups); + return cudnn_convolution_backward_input(c, input_size, grad_output, weight, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +Tensor cudnn_convolution_backward_input( + IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg grad_output{ grad_output_t, "grad_output", 1 }, + weight{ weight_t, "weight", 2 }; + return cudnn_convolution_backward_input( + "cudnn_convolution_backward_input", + input_size, grad_output, weight, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +std::tuple cudnn_convolution_backward( + const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + + Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); + + Tensor grad_input, grad_weight; + if (input.numel() == 0) { + if (output_mask[0]) { + grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (output_mask[1]) { + grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + } else { + if (output_mask[0]) { + grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + if (output_mask[1]) { + grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + } + + return std::tuple{grad_input, grad_weight}; +} + +Tensor cudnn_convolution_transpose( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg input { input_t, "input", 1 }, + weight { weight_t, "weight", 2 }; + CheckedFrom c = "cudnn_convolution_transpose"; + auto output_t = cudnn_convolution_transpose_forward( + c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return output_t; +} + +// --------------------------------------------------------------------- +// +// Convolution backward (weight) +// +// --------------------------------------------------------------------- + +Tensor cudnn_convolution_backward_weight( + CheckedFrom c, + IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ? + at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; + + Tensor grad_output_contig_t = grad_output_t.contiguous(layout); + // Make sure that NC11 strides follow formula + grad_output_contig_t.resize_(grad_output_contig_t.sizes(), layout); + TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 }; + + Tensor input_contig_t = input_t.contiguous(layout); + input_contig_t.resize_(input_contig_t.sizes(), layout); + TensorArg input{ input_contig_t, "input", 2}; + + checkAllSameType(c, {grad_output_contig, input}); + checkAllSameGPU(c, {grad_output_contig, input}); + + auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), layout); + + // For uniformity with everything else, although it seems grad_weight + // would be unambiguous too. + TensorArg grad_weight{ grad_weight_t, "result", 0 }; + convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups); + + raw_cudnn_convolution_backward_weight_out( + *grad_weight, *grad_output_contig, *input, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + + return grad_weight_t; +} + +Tensor cudnn_convolution_backward_weight( + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + return cudnn_convolution_backward_weight( + "cudnn_convolution_backward_weight", + weight_size, grad_output_t, input_t, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +Tensor cudnn_convolution_transpose_backward_weight( + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + return cudnn_convolution_backward_weight( + "cudnn_convolution_backward_weight", + weight_size, input_t, grad_output_t, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +}} + +#endif // AT_CUDNN_ENABLED diff --git a/aten/src/ATen/native/cudnn/ConvShared.h b/aten/src/ATen/native/cudnn/ConvShared.h new file mode 100644 index 000000000000..e30b5c7be581 --- /dev/null +++ b/aten/src/ATen/native/cudnn/ConvShared.h @@ -0,0 +1,88 @@ +#include + +#include +#include +#include +#include + +namespace at { namespace native { + +// --------------------------------------------------------------------- +// +// Helper classes +// +// --------------------------------------------------------------------- + +// This POD struct is used to let us easily compute hashes of the +// parameters +struct ConvolutionParams +{ + cudnnDataType_t dataType; + int input_size[2 + max_dim]; + int input_stride[2 + max_dim]; + int weight_size[2 + max_dim]; + int padding[max_dim]; + int stride[max_dim]; + int dilation[max_dim]; + int64_t groups; + bool deterministic; + bool allow_tf32; + // NB: transposed purposely omitted: transposed just swaps + // forward and backward, so you can reuse the benchmark entry, +}; + +// Convenience struct for passing around descriptors and data +// pointers +struct ConvolutionArgs { + cudnnHandle_t handle; + ConvolutionParams params; + TensorDescriptor idesc, odesc; + FilterDescriptor wdesc; + const Tensor& input, output, weight; + ConvolutionDescriptor cdesc; + + ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) { + } +}; + +std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params); + +// NB: This can't be a constructor, because then ConvolutionParams +// would not be a POD anymore. +// TODO: Use TensorGeometry here instead of the entire Tensor, which we +// don't actually need. (OTOH: We can always pass in +// grad_input/grad_output, so this is not very pressing) +void setConvolutionParams( + ConvolutionParams* params, + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool deterministic, bool allow_tf32); + +std::string repro_from_args(const ConvolutionArgs& args); + +std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args); + +// --------------------------------------------------------------------- +// +// Raw functions +// +// --------------------------------------------------------------------- + +void raw_cudnn_convolution_forward_out( + const Tensor& output, const Tensor& input, const Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32); + +void raw_cudnn_convolution_backward_input_out( + const at::Tensor& grad_input, + const at::Tensor& grad_output, + const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32); + +void raw_cudnn_convolution_backward_weight_out( + const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32); + +}} diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp similarity index 54% rename from aten/src/ATen/native/cudnn/Conv.cpp rename to aten/src/ATen/native/cudnn/Conv_v7.cpp index 4524af2fe244..5e1f124f1185 100644 --- a/aten/src/ATen/native/cudnn/Conv.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -1,3 +1,7 @@ +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() + #include #include #include @@ -5,80 +9,10 @@ #include #include #include -#include #include -#include - -#if !AT_CUDNN_ENABLED() - -namespace at { namespace native { - -// See Note [ATen preprocessor philosophy] - -at::Tensor cudnn_convolution( - const at::Tensor& input, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_backward_input( - IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_backward_weight( - IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support"); -} - -std::tuple cudnn_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_transpose( - const at::Tensor& input, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_transpose_backward_input( - const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_transpose_backward_weight( - IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support"); -} - -std::tuple cudnn_convolution_transpose_backward( - const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); -} - -}} - -#else // AT_CUDNN_ENABLED +#include #include - -#include -#include #include #include #include @@ -130,217 +64,6 @@ namespace at { namespace native { // TODO: Go through all the checking code again and make sure // we haven't missed anything. -// TODO: Move this into the standard library, with a better name? -Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) { - auto group_size = t.size(dim) / groups; - return t.narrow(dim, group_idx * group_size, group_size); -} - -// --------------------------------------------------------------------- -// -// Checking -// -// --------------------------------------------------------------------- - -// Note [Legacy CuDNN grouped convolution support] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// CuDNN earlier than CuDNN 7 does not directly support group -// convolution, so we provide support for it by sequentially -// running a convolution per group with appropriately -// adjusted sizes. https://blog.yani.io/filter-group-tutorial/ -// has a fairly good diagram explaining how it works. - -// Used on pad, stride and dilation -static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name) -{ - TORCH_CHECK(args.size() <= expected_size, - "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ", - expected_size, " (while checking arguments for ", c, ")"); - TORCH_CHECK(args.size() >= expected_size, - "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ", - expected_size, " (while checking arguments for ", c, ")"); - - auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;}); - if (num_negative_values > 0){ - std::stringstream ss; - ss << arg_name << " should be greater than zero but got ("; - std::copy(args.begin(), args.end() - 1, std::ostream_iterator(ss,", ")); - ss << args.back() << ")" << " (while checking arguments for " << c << ")"; - AT_ERROR(ss.str()); - } -} - - -// NOTE [ Convolution checks ] -// -// NB: For many call sites, it is not strictly necessary to check all of -// these relationships (for example, for forward convolution, we compute -// the size of output ourselves, so we don't actually need to check -// output. However, writing a single function that does everything -// means we get to reuse it for both forwards and all backwards -// variants, even when the set of "real" inputs varies. The magic of -// relational computing! -// -// (There is one downside, which is that it is slightly harder to write -// error messages which are able to distinguish between real inputs -// (which the user can change) and computed inputs (which the user can -// only indirectly affect). It would be an interesting exercise to -// come up with a general framework to handle such situations.) -static void convolution_shape_check( - CheckedFrom c, - const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) -{ - check_args(c, padding, input->dim() - 2, "padding"); - check_args(c, stride, padding.size(), "stride"); - check_args(c, dilation, padding.size(), "dilation"); - - // Input - checkDimRange(c, input, 3, 6 /* exclusive */); - checkSize(c, input, input_channels_dim, weight->size(1) * groups); - - // Weight - checkSameDim(c, input, weight); - - // TODO: check that output->size() matches output_sizes - // TODO: check that weight matches output->sizes() - checkSameDim(c, input, output); -} - -// This POD struct is used to let us easily compute hashes of the -// parameters -struct ConvolutionParams -{ - cudnnDataType_t dataType; - int input_size[2 + max_dim]; - int input_stride[2 + max_dim]; - int weight_size[2 + max_dim]; - int padding[max_dim]; - int stride[max_dim]; - int dilation[max_dim]; - int64_t groups; - bool deterministic; - bool allow_tf32; - // NB: transposed purposely omitted: transposed just swaps - // forward and backward, so you can reuse the benchmark entry, -}; - -std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params) { - out << "ConvolutionParams \n" - << " data_type = " << cudnnTypeToString(params.dataType) << "\n" - << " padding = " << ArrayRef{params.padding} << "\n" - << " stride = " << ArrayRef{params.stride} << "\n" - << " dilation = " << ArrayRef{params.dilation} << "\n" - << " groups = " << params.groups << "\n" - << " deterministic = " << (params.deterministic ? "true" : "false") << "\n" - << " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n"; - - return out; -} - -// NB: This can't be a constructor, because then ConvolutionParams -// would not be a POD anymore. -// TODO: Use TensorGeometry here instead of the entire Tensor, which we -// don't actually need. (OTOH: We can always pass in -// grad_input/grad_output, so this is not very pressing) -void setConvolutionParams( - ConvolutionParams* params, - const at::Tensor& input, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool deterministic, bool allow_tf32) { - - cudnnDataType_t dataType = getCudnnDataType(input); - memset(params, 0, sizeof(ConvolutionParams)); - params->dataType = dataType; - // ASSERT(weight.dim() == input.dim()) - for (int i = 0; i != input.dim(); ++i) { - params->input_size[i] = (int) input.size(i); - params->input_stride[i] = (int) input.stride(i); - params->weight_size[i] = (int) weight.size(i); - } - // ASSERT(padding.size() == stride.size()) - // ASSERT(padding.size() == dilation.size()) - for (size_t i = 0; i != padding.size(); ++i) { - params->padding[i] = padding[i]; - params->stride[i] = stride[i]; - params->dilation[i] = dilation[i]; - } - // In principle, we shouldn't parametrize by groups for legacy - // CuDNN, but it doesn't seem worth the effort to actually do this. - params->groups = groups; - params->deterministic = deterministic; - params->allow_tf32 = allow_tf32; -} - -// Convenience struct for passing around descriptors and data -// pointers -struct ConvolutionArgs { - cudnnHandle_t handle; - ConvolutionParams params; - TensorDescriptor idesc, odesc; - FilterDescriptor wdesc; - const Tensor& input, output, weight; - ConvolutionDescriptor cdesc; - - ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) { - } -}; - -std::string repro_from_args(const ConvolutionArgs& args) { - auto pybool = [](bool b) -> const char* { return b ? "True" : "False"; }; - std::string partial_dtype; - switch (args.params.dataType) { - case CUDNN_DATA_FLOAT: partial_dtype = "float"; break; - case CUDNN_DATA_DOUBLE: partial_dtype = "double"; break; - case CUDNN_DATA_HALF: partial_dtype = "half"; break; - default: partial_dtype = "unsupported"; - } - const std::string full_dtype = "torch." + partial_dtype; - const int out_channels = args.weight.sizes()[0]; - const int in_channels = args.weight.sizes()[1] * args.params.groups; - const size_t dim = args.input.sizes().size(); - const std::string channels_last_xd = dim == 4 ? "channels_last" : "channels_last_3d"; - const std::string to_channels_last = args.input.suggest_memory_format() == at::MemoryFormat::ChannelsLast \ - ? ".to(memory_format=torch." + channels_last_xd + ")" : ""; - - std::ostringstream ss; - ss << "You can try to repro this exception using the following code snippet. "; - ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n"; - ss << "import torch\n"; - ss << "torch.backends.cuda.matmul.allow_tf32 = " << pybool(at::globalContext().allowTF32CuBLAS()) << "\n"; - ss << "torch.backends.cudnn.benchmark = " << pybool(at::globalContext().benchmarkCuDNN()) << "\n"; - ss << "torch.backends.cudnn.deterministic = " << pybool(args.params.deterministic) << "\n"; - ss << "torch.backends.cudnn.allow_tf32 = " << pybool(args.params.allow_tf32) << "\n"; - ss << "data = torch.randn(" << args.input.sizes() << ", dtype=" << full_dtype << ", "; - ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n"; - ss << "net = torch.nn.Conv" << dim-2 << "d(" << in_channels << ", " << out_channels << ", "; - ss << "kernel_size=" << args.weight.sizes().slice(2) << ", "; - ss << "padding=" << ArrayRef(args.params.padding, dim-2) << ", "; - ss << "stride=" << ArrayRef(args.params.stride, dim-2) << ", "; - ss << "dilation=" << ArrayRef(args.params.dilation, dim-2) << ", "; - ss << "groups=" << args.params.groups << ")\n"; - ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last << "\n"; - ss << "out = net(data)\n"; - ss << "out.backward(torch.randn_like(out))\n"; - ss << "torch.cuda.synchronize()\n\n"; - - return ss.str(); -} - -std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args) { - out << repro_from_args(args) // already has a trailing newline - << args.params // already has a trailing newline - << "input: " << args.idesc // already has a trailing newline - << "output: " << args.odesc // already has a trailing newline - << "weight: " << args.wdesc // already has a trailing newline - << "Pointer addresses: " << "\n" - << " input: " << args.input.data_ptr() << "\n" - << " output: " << args.output.data_ptr() << "\n" - << " weight: " << args.weight.data_ptr() << "\n"; - - return out; -} - // --------------------------------------------------------------------- // // Benchmarking @@ -781,18 +504,7 @@ inline Tensor allocate_workspace(size_t size, const Tensor &other) { return at::empty({static_cast(size)}, other.options().dtype(kByte)); } -// NOTE [ Convolution design ] -// -// cuDNN convolutions does not handle bias. Bias is handled outside. -// -// The general strategy: -// -// - cudnn_convolution (Tensor) -// Entry points for clients -// -// - cudnn_convolution_forward (TensorArg) -// Entry point, which may be reused between regular -// convolution and transposed convolution. +// NOTE [ raw_cudnn_convolution_forward_out ] // // - raw_cudnn_convolution_forward_out (Tensor) // Functiont that handles tensors that are too large to use 32bit indexing. @@ -802,14 +514,6 @@ inline Tensor allocate_workspace(size_t size, const Tensor &other) { // Low level function which invokes CuDNN, and takes an output // tensor which is directly written to (thus _out). // -// Where does argument checking happen? Here's the division of -// responsibility: -// - Things that happen in at::Tensor -// - TensorArg allocation -// - Things that happen in TensorArg -// - Check arguments (type, GPU, shape) -// -// TODO: Consider renaming zero-indexed arguments to "self" // --------------------------------------------------------------------- @@ -885,16 +589,6 @@ if (args.params.dataType == CUDNN_DATA_FLOAT) { // // --------------------------------------------------------------------- -// The raw API directly invokes CuDNN and does not emulate support -// for group convolution on old versions of CuDNN. -// -// There are a few reasons this should never be directly exposed -// via ATen: -// -// - It takes output as a parameter (this should be computed!) -// - It doesn't do input checking -// - It doesn't resize output (it is assumed to be correctly sized) -// void raw_cudnn_convolution_forward_out_32bit( const Tensor& output, const Tensor& input, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, @@ -946,90 +640,6 @@ void raw_cudnn_convolution_forward_out( split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit); } -Tensor cudnn_convolution_forward( - CheckedFrom c, - const TensorArg& input, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - checkAllSameType(c, {input, weight}); - checkAllSameGPU(c, {input, weight}); - - auto layout = cudnn_conv_use_channels_last(*input, *weight) ? - at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - auto output_t = at::empty( - conv_output_size(input->sizes(), weight->sizes(), - padding, stride, dilation), - input->options(), - layout); - - if (output_t.numel() == 0) { - return output_t; - } - - // Avoid ambiguity of "output" when this is being used as backwards - TensorArg output{ output_t, "result", 0 }; - convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); - - // See #4500 - Tensor weight_contig = weight->contiguous(layout); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), layout); - Tensor input_contig = input->contiguous(layout); - input_contig.resize_(input_contig.sizes(), layout); - - raw_cudnn_convolution_forward_out( - *output, input_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - - return *output; -} - -Tensor cudnn_convolution( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg input { input_t, "input", 1 }, - weight { weight_t, "weight", 2 }; - CheckedFrom c = "cudnn_convolution"; - auto output_t = cudnn_convolution_forward( - c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - return output_t; -} - -// NB: output_padding not needed here, as there is no ambiguity to -// resolve -Tensor cudnn_convolution_transpose_backward_input( - const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg grad_output { grad_output_t, "grad_output", 1 }, - weight { weight_t, "weight", 2 }; - return cudnn_convolution_forward( - "cudnn_convolution_transpose_backward_input", - grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -std::tuple cudnn_convolution_transpose_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - - Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); - - Tensor grad_input, grad_weight; - if (output_mask[0]) { - grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - if (output_mask[1]) { - grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - - return std::tuple{grad_input, grad_weight}; -} - // --------------------------------------------------------------------- // // Convolution backward / Transposed convolution forward @@ -1089,115 +699,6 @@ void raw_cudnn_convolution_backward_input_out( split_batch_dim_to_32bit_out(grad_input, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 128, raw_cudnn_convolution_backward_input_out_32bit); } -// NOTE [ Backward vs transpose convolutions ] -// -// Backward and transpose are algorithmically equivalent, but they -// compute their geometry differently. In a backwards, you knew what -// the original size of the input tensor was, so you can cache that -// geometry and fill it directly. In transposed convolution, it is -// more conventional to not explicitly specify the output (previously -// input) size, and compute it. This, however, leaves a degree of -// freedom; this degree of freedom is resolved using the -// output_padding parameter. Both of these interfaces are equivalent, -// but they are differently convenient depending on the use case. - -Tensor cudnn_convolution_backward_input( - CheckedFrom c, - IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - checkAllSameType(c, {grad_output, weight}); - checkAllSameGPU(c, {grad_output, weight}); - - auto layout = cudnn_conv_use_channels_last(*grad_output, *weight) ? - at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - auto grad_input_t = at::empty(input_size, grad_output->options(), layout); - - // Avoid "grad_input" when this is being used as transposed convolution - TensorArg grad_input{ grad_input_t, "result", 0 }; - convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); - - // See #4500 - Tensor weight_contig = weight->contiguous(layout); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), layout); - - Tensor grad_output_contig = grad_output->contiguous(layout); - grad_output_contig.resize_(grad_output_contig.sizes(), layout); - - raw_cudnn_convolution_backward_input_out( - *grad_input, grad_output_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - - return *grad_input; -} - -Tensor cudnn_convolution_transpose_forward( - CheckedFrom c, - const TensorArg& grad_output, const TensorArg& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(), - padding, output_padding, stride, dilation, groups); - return cudnn_convolution_backward_input(c, input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -Tensor cudnn_convolution_backward_input( - IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - weight{ weight_t, "weight", 2 }; - return cudnn_convolution_backward_input( - "cudnn_convolution_backward_input", - input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -std::tuple cudnn_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - - Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); - - Tensor grad_input, grad_weight; - if (input.numel() == 0) { - if (output_mask[0]) { - grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - if (output_mask[1]) { - grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - } else { - if (output_mask[0]) { - grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - if (output_mask[1]) { - grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - } - - return std::tuple{grad_input, grad_weight}; -} - -Tensor cudnn_convolution_transpose( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg input { input_t, "input", 1 }, - weight { weight_t, "weight", 2 }; - CheckedFrom c = "cudnn_convolution_transpose"; - auto output_t = cudnn_convolution_transpose_forward( - c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - return output_t; -} - // --------------------------------------------------------------------- // // Convolution backward (weight) @@ -1295,115 +796,6 @@ void raw_cudnn_convolution_backward_weight_out( TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN."); } -Tensor cudnn_convolution_backward_weight( - CheckedFrom c, - IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ? - at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - - Tensor grad_output_contig_t = grad_output_t.contiguous(layout); - // Make sure that NC11 strides follow formula - grad_output_contig_t.resize_(grad_output_contig_t.sizes(), layout); - TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 }; - - Tensor input_contig_t = input_t.contiguous(layout); - input_contig_t.resize_(input_contig_t.sizes(), layout); - TensorArg input{ input_contig_t, "input", 2}; - - checkAllSameType(c, {grad_output_contig, input}); - checkAllSameGPU(c, {grad_output_contig, input}); - - auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), layout); - - // For uniformity with everything else, although it seems grad_weight - // would be unambiguous too. - TensorArg grad_weight{ grad_weight_t, "result", 0 }; - convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups); - - raw_cudnn_convolution_backward_weight_out( - *grad_weight, *grad_output_contig, *input, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - - return grad_weight_t; -} - -Tensor cudnn_convolution_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - return cudnn_convolution_backward_weight( - "cudnn_convolution_backward_weight", - weight_size, grad_output_t, input_t, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -Tensor cudnn_convolution_transpose_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - return cudnn_convolution_backward_weight( - "cudnn_convolution_backward_weight", - weight_size, input_t, grad_output_t, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - }} // namespace at::native #endif - - -namespace at { namespace native { - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_deprecated( - const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) { - auto output = at::cudnn_convolution(input, weight, padding, stride, dilation, groups, benchmark, deterministic); - if (bias.defined()) { - output = output + reshape_bias(input.dim(), bias); - } - return output; -} - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_deprecated2( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - return at::cudnn_convolution(input_t, weight_t, padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); -} - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_transpose_deprecated( - const Tensor& input, const Tensor& weight, const Tensor& bias /* optional */, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - auto output = at::cudnn_convolution_transpose(input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic); - if (bias.defined()) { - output = output + reshape_bias(input.dim(), bias); - } - return output; -} - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_transpose_deprecated2( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - return at::cudnn_convolution_transpose(input_t, weight_t, padding, output_padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); -} - -}} // namespace at::native diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp new file mode 100644 index 000000000000..53f8c37f5e64 --- /dev/null +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -0,0 +1,5 @@ +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 +// Coming soon +#endif // AT_CUDNN_ENABLED and CUDNN_VERSION diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm index 8dd27aa1c3ed..9b1ff29feaa1 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm @@ -608,7 +608,9 @@ Tensor copy_to_host(const Tensor& input) { MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); auto&& sizes = [X sizes]; - MetalTensor mt{sizes}; + auto dummy = at::zeros(input.sizes()).contiguous(); + auto strides = dummy.strides(); + MetalTensor mt{sizes, strides.vec()}; mt.texture()->setCommandBuffer(commandBuffer); mt.texture()->allocateTextureStorage(sizes); MPSImage* Y = imageFromMetalTensor(mt); diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm index bbcbfe10fd01..cfef7c16646c 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm @@ -378,7 +378,7 @@ bool test_add() { { \ auto X1 = torch::rand(a1, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \ auto X2 = torch::rand(a2, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \ - auto Y1 = at::native::add(X1, X2); \ + auto Y1 = at::add(X1, X2); \ auto MX1 = X1.metal(); \ auto MX2 = X2.metal(); \ auto Y2 = at::native::metal::mpscnn::add(MX1, MX2).cpu(); \ diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index 0d4af95c7a76..92473ecc68c8 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -62,7 +62,6 @@ std::tuple miopen_batch_norm( running_mean{ running_mean_t, "running_mean", 4 }, running_var{ running_var_t, "running_var", 5 }; CheckedFrom c = "miopen_batch_norm"; - setMIOpenStreamToCurrent(); checkAllDefined(c, {input, weight, bias}); if (!training) { @@ -151,7 +150,6 @@ std::tuple miopen_batch_norm_backward( save_mean{ save_mean_t, "save_mean", 4 }, save_var{ save_var_t, "save_var", 5 }; CheckedFrom c = "miopen_batch_norm_backward"; - setMIOpenStreamToCurrent(); checkAllDefined(c, {input, grad_output, weight, save_mean, save_var}); checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var}); diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index 27e119d377bc..f0b0d6fdd5b7 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -624,7 +624,6 @@ Tensor miopen_convolution( TensorArg input { input_t, "input", 1 }, weight { weight_t, "weight", 2 }, bias { bias_t, "bias", 3 }; - setMIOpenStreamToCurrent(); CheckedFrom c = "miopen_convolution"; auto output_t = miopen_convolution_forward( c, input, weight, padding, stride, dilation, groups, benchmark, deterministic); @@ -699,7 +698,6 @@ Tensor miopen_depthwise_convolution( TensorArg input { input_t, "input", 1 }, weight { weight_t, "weight", 2 }, bias { bias_t, "bias", 3 }; - setMIOpenStreamToCurrent(); CheckedFrom c = "miopen_depthwise_convolution"; auto output_t = miopen_depthwise_convolution_forward( c, input, weight, padding, stride, dilation, groups, benchmark, deterministic); @@ -716,7 +714,6 @@ Tensor miopen_convolution_transpose_backward_input( { TensorArg grad_output { grad_output_t, "grad_output", 1 }, weight { weight_t, "weight", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_forward( "miopen_convolution_transpose_backward_input", grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); @@ -827,7 +824,6 @@ Tensor miopen_convolution_backward_input( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, weight{ weight_t, "weight", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_backward_input( "miopen_convolution_backward_input", input_size, grad_output, weight, @@ -897,7 +893,6 @@ Tensor miopen_depthwise_convolution_backward_input( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, weight{ weight_t, "weight", 2 }; - setMIOpenStreamToCurrent(); return miopen_depthwise_convolution_backward_input( "miopen_depthwise_convolution_backward_input", input_size, grad_output, weight, @@ -1087,7 +1082,6 @@ Tensor miopen_convolution_backward_weight( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, input{ input_t, "input", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_backward_weight( "miopen_convolution_backward_weight", weight_size, grad_output, input, @@ -1103,7 +1097,6 @@ Tensor miopen_convolution_transpose_backward_weight( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, input{ input_t, "input", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_backward_weight( "miopen_convolution_backward_weight", weight_size, input, grad_output, @@ -1119,7 +1112,6 @@ Tensor miopen_depthwise_convolution_backward_weight( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, input{ input_t, "input", 2 }; - setMIOpenStreamToCurrent(); return miopen_depthwise_convolution_backward_weight( "miopen_depthwise_convolution_backward_weight", weight_size, grad_output, input, @@ -1136,7 +1128,6 @@ Tensor miopen_convolution_backward_bias( const Tensor& grad_output_t) { TensorArg grad_output{ grad_output_t, "grad_output", 1 }; - setMIOpenStreamToCurrent(); auto grad_bias_t = at::empty( { grad_output->size(output_channels_dim) }, grad_output->options()); diff --git a/aten/src/ATen/native/miopen/RNN_miopen.cpp b/aten/src/ATen/native/miopen/RNN_miopen.cpp index 1493cece3212..10b535f890ac 100644 --- a/aten/src/ATen/native/miopen/RNN_miopen.cpp +++ b/aten/src/ATen/native/miopen/RNN_miopen.cpp @@ -509,7 +509,6 @@ std::tuple miopen_rnn( size_t reserver_size; MIOPEN_CHECK(miopenGetRNNTrainingReserveSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &reserver_size)); reserve = at::empty(reserver_size, input.options().dtype(kByte)); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNForwardTraining(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), @@ -521,7 +520,6 @@ std::tuple miopen_rnn( workspace.data_ptr(), workspace_size, reserve.data_ptr(), reserver_size )); } else { //Inference. reserve = at::empty({0}, input.options().dtype(kByte)); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNForwardInference(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), @@ -630,7 +628,6 @@ std::tuple miopen_rnn_backward_input( )); auto workspace = at::empty(workspace_size, input.options().dtype(kByte)); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNBackwardData( handle, descs.rnn_desc.desc(), @@ -715,7 +712,6 @@ std::vector miopen_rnn_backward_weight( auto x_descs_arr = descs.get_x_descs(); auto y_descs_arr = descs.get_y_descs(); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNBackwardWeights( handle, descs.rnn_desc.desc(), diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1a9650ccfc25..768ddf2fc17d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -230,6 +230,7 @@ DefaultBackend: abs_ - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: abs_out @@ -367,6 +368,7 @@ - func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor use_c10_dispatcher: full + structured_delegate: add.out variants: function, method dispatch: CPU, CUDA: add @@ -376,12 +378,15 @@ - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) use_c10_dispatcher: full variants: method + structured_delegate: add.out dispatch: CPU, CUDA: add_ SparseCPU, SparseCUDA: add_sparse_ MkldnnCPU: mkldnn_add_ - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: add_out SparseCPU: add_out_sparse_cpu @@ -1115,6 +1120,7 @@ - func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) use_c10_dispatcher: full variants: method + manual_cpp_binding: True - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor use_c10_dispatcher: hacky_wrapper_for_legacy_signatures @@ -1481,6 +1487,7 @@ SparseCPU, SparseCUDA: div_sparse_ - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: full dispatch: CPU, CUDA: div_out SparseCPU, SparseCUDA: div_out_sparse_zerodim @@ -3554,8 +3561,9 @@ - func: size.int(Tensor self, int dim) -> int use_c10_dispatcher: full - variants: function, method + variants: function device_guard: False + manual_cpp_binding: True - func: size.Dimname(Tensor self, Dimname dim) -> int variants: function, method @@ -3717,8 +3725,9 @@ - func: stride.int(Tensor self, int dim) -> int use_c10_dispatcher: full - variants: function, method + variants: function device_guard: False + manual_cpp_binding: True - func: stride.Dimname(Tensor self, Dimname dim) -> int variants: function, method @@ -6238,15 +6247,13 @@ - func: eig.e(Tensor self, bool eigenvectors=False, *, Tensor(a!) e, Tensor(b!) v) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) dispatch: - CPU: legacy::cpu::_th_eig_out - CUDA: eig_cuda_out + DefaultBackend: eig_out - func: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors) use_c10_dispatcher: full variants: method, function dispatch: - CPU: legacy::cpu::_th_eig - CUDA: eig_cuda + DefaultBackend: eig - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) dispatch: @@ -7623,6 +7630,54 @@ CPU: foreach_tensor_frac_slow_ CUDA: foreach_tensor_frac_cuda_ +- func: _foreach_reciprocal(Tensor[] tensors) -> Tensor[] + use_c10_dispatcher: full + device_guard: False + variants: function + dispatch: + CPU: foreach_tensor_reciprocal_slow + CUDA: foreach_tensor_reciprocal_cuda + +- func: _foreach_reciprocal_(Tensor(a!)[] self) -> () + use_c10_dispatcher: full + device_guard: False + variants: function + dispatch: + CPU: foreach_tensor_reciprocal_slow_ + CUDA: foreach_tensor_reciprocal_cuda_ + +- func: _foreach_sigmoid(Tensor[] tensors) -> Tensor[] + use_c10_dispatcher: full + device_guard: False + variants: function + dispatch: + CPU: foreach_tensor_sigmoid_slow + CUDA: foreach_tensor_sigmoid_cuda + +- func: _foreach_sigmoid_(Tensor(a!)[] self) -> () + use_c10_dispatcher: full + device_guard: False + variants: function + dispatch: + CPU: foreach_tensor_sigmoid_slow_ + CUDA: foreach_tensor_sigmoid_cuda_ + +- func: _foreach_trunc(Tensor[] tensors) -> Tensor[] + use_c10_dispatcher: full + device_guard: False + variants: function + dispatch: + CPU: foreach_tensor_trunc_slow + CUDA: foreach_tensor_trunc_cuda + +- func: _foreach_trunc_(Tensor(a!)[] self) -> () + use_c10_dispatcher: full + device_guard: False + variants: function + dispatch: + CPU: foreach_tensor_trunc_slow_ + CUDA: foreach_tensor_trunc_cuda_ + - func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () use_c10_dispatcher: full device_guard: False diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index a2349790d117..b4cff64b309d 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -314,12 +314,16 @@ struct CAFFE2_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { int64_t bit_rate, c10::QScheme q_scheme, int64_t version) - : packed_w(std::move(packed_w)), - w_scale(std::move(w_scale)), - w_zp(std::move(w_zp)), - bit_rate_(bit_rate), - q_scheme(q_scheme), - version_(version) {} + : packed_w(std::move(packed_w)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + bit_rate_(bit_rate), + q_scheme(q_scheme), + version_(version) { + if (!packed_w.is_contiguous()) { + packed_w = packed_w.contiguous(); + } + } at::Tensor packed_w; std::vector w_scale; diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index 1c52242641e7..c4e92dd039e2 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -45,11 +45,12 @@ at::Tensor embedding_bag_4bit_impl( } } - const int64_t N = weight.size(0); - const int64_t weight_size = weight.size(1); + const auto weight_sizes = weight.sizes(); + const int64_t N = weight_sizes[0]; + const int64_t weight_size = weight_sizes[1]; const int64_t D = (weight_size - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset - const int64_t M = offsets.size(0); + const int64_t M = offsets.sizes()[0]; int64_t output_size = M - 1; std::vector offsets_include_last_val; @@ -231,9 +232,10 @@ at::Tensor embedding_bag_byte_impl( } } - const int64_t N = weight.size(0); - const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias - const int64_t M = offsets.size(0); + const auto weight_sizes = weight.sizes(); + const int64_t N = weight_sizes[0]; + const int64_t D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias + const int64_t M = offsets.sizes()[0]; int64_t output_size = M - 1; std::vector offsets_include_last_val; @@ -254,7 +256,8 @@ at::Tensor embedding_bag_byte_impl( } std::vector shape; if (indices.dim() == 2 && is_embedding_op) { - shape = {indices.size(0), indices.size(1), D}; + const auto indices_sizes = indices.sizes(); + shape = {indices_sizes[0], indices_sizes[1], D}; } else { shape = {output_size, D}; } @@ -350,8 +353,8 @@ at::Tensor embedding_bag_byte_helper( !offsets_in.has_value(), "embedding_bag_byte operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); - offsets = - at::arange(0, indices.numel(), indices.size(1), indices.scalar_type()); + offsets = at::arange( + 0, indices.numel(), indices.sizes()[1], indices.scalar_type()); } else { TORCH_CHECK( offsets_in.has_value(), @@ -369,12 +372,16 @@ at::Tensor embedding_bag_byte_helper( "Expect 32 or 64 bit offsets, but found ", offsets.scalar_type(), " instead."); + TORCH_CHECK( + weight.is_contiguous() && indices.is_contiguous() && + offsets.is_contiguous(), + "Expect weight, indices, and offsets to be contiguous."); // Using helper function to support different type combination without the // need to cast, which can be additional performance overhead if (indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kInt) { return embedding_bag_byte_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -385,7 +392,7 @@ at::Tensor embedding_bag_byte_helper( } else if ( indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) { return embedding_bag_byte_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -396,7 +403,7 @@ at::Tensor embedding_bag_byte_helper( } else if ( indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) { return embedding_bag_byte_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -408,7 +415,7 @@ at::Tensor embedding_bag_byte_helper( // default case given the TORCH_CHECK above return embedding_bag_byte_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -439,8 +446,8 @@ at::Tensor embedding_bag_4bit_helper( !offsets_in.has_value(), "embedding_bag_4bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); - offsets = - at::arange(0, indices.numel(), indices.size(1), indices.scalar_type()); + offsets = at::arange( + 0, indices.numel(), indices.sizes()[1], indices.scalar_type()); } else { TORCH_CHECK( offsets_in.has_value(), @@ -458,12 +465,16 @@ at::Tensor embedding_bag_4bit_helper( "Expect 32 or 64 bit offsets, but found ", offsets.scalar_type(), " instead."); + TORCH_CHECK( + weight.is_contiguous() && indices.is_contiguous() && + offsets.is_contiguous(), + "Expect weight, indices, and offsets to be contiguous."); // Using helper function to support different type combination without the // need to cast, which can be additional performance overhead if (indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kInt) { return embedding_bag_4bit_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -473,7 +484,7 @@ at::Tensor embedding_bag_4bit_helper( } else if ( indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) { return embedding_bag_4bit_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -483,7 +494,7 @@ at::Tensor embedding_bag_4bit_helper( } else if ( indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) { return embedding_bag_4bit_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -492,7 +503,7 @@ at::Tensor embedding_bag_4bit_helper( include_last_offset); } return embedding_bag_4bit_impl( - weight.contiguous(), + weight, indices, offsets, pruned_weights, @@ -511,7 +522,7 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( bool include_last_offset, bool is_embedding_op) { return embedding_bag_byte_helper( - packed_w.contiguous(), + packed_w, indices, offsets_in, pruned_weights, @@ -538,7 +549,7 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( } return embedding_bag_4bit_helper( - packed_w.contiguous(), + packed_w, indices, offsets_in, pruned_weights, @@ -564,7 +575,7 @@ Tensor embedding_bag_byte_rowwise_offsets( const c10::optional& compressed_indices_mapping, bool include_last_offset) { return embedding_bag_byte_helper( - weight.contiguous(), + weight, indices, offsets_in, pruned_weights, @@ -594,7 +605,7 @@ Tensor embedding_bag_4bit_rowwise_offsets( } return embedding_bag_4bit_helper( - weight.contiguous(), + weight, indices, offsets_in, pruned_weights, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt index 9f8cb6d9ed09..99bf8ba07074 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt @@ -271,7 +271,7 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") set_property(SOURCE ${PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") endif() endif() -if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR IOS_ARCH MATCHES "^arm64.*") +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$" OR IOS_ARCH MATCHES "^arm64.*") set_property(SOURCE ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 ") if(IOS) set_property(SOURCE ${PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 91a275a6aecf..2c8a6d4e4946 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -132,7 +132,6 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor")); diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 3c836d0258d1..2bee0a581366 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -241,7 +241,7 @@ static SparseTensor& coalesce_(SparseTensor& tensor) { // values=[1., 1.] (after truncation), which sum to 2.f instead of 3.f. // To perform floor division the sparse tensor must be coalesced first. -SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) { +SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, SparseTensor& r) { TORCH_CHECK(value.dim() == 0, "Sparse division requires a scalar or ", "zero-dim dense tensor divisor (got shape ", value.sizes(), " for divisor)"); TORCH_CHECK(!value.is_sparse(), "Sparse division requires a scalar or ", @@ -279,15 +279,15 @@ Tensor div_sparse(const Tensor& self, const Tensor& value) { commonDtype = typeMetaToScalarType(at::get_default_dtype()); } Tensor result = at::empty({0}, self.options().dtype(commonDtype)); - return div_out_sparse_zerodim(result, self, value); + return div_out_sparse_zerodim(self, value, result); } Tensor& div_sparse_(Tensor& self, const Tensor& value) { - return div_out_sparse_zerodim(self, self, value); + return div_out_sparse_zerodim(self, value, self); } SparseTensor& div_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) { - return div_out_sparse_zerodim(r, t, wrapped_scalar_tensor(value)); + return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r); } // -------------------------------------------------------------------- @@ -1108,7 +1108,7 @@ SparseTensor& _sspaddmm_out_cpu( "sspaddmm: Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1)); int64_t nnz = sparse._nnz(); - // We have to make indices contiguous as we use indices.data_ptr in _to_csr which assumes row-contiguous storage + // We have to make indices contiguous as we use indices.data_ptr in _to_csr which assumes row-contiguous storage LongTensor indices = sparse._indices().contiguous(); Tensor values = sparse._values(); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index 5d25138500d7..660862181262 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -96,18 +96,16 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) { dim3 block(C10_WARP_SIZE, SZ); AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(), "coalesce_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "coalesce_sparse_cuda", [&] { - using cuda_accscalar_t = acc_type; - apply::coalesceValuesKernel<<>>( - uniqueOffsets.data_ptr(), - origIndices.data_ptr(), - values.data_ptr(), - newValues.data_ptr(), - nnz, - newNnz, - stride - ); - }); + using cuda_accscalar_t = acc_type; + apply::coalesceValuesKernel<<>>( + uniqueOffsets.data_ptr(), + origIndices.data_ptr(), + values.data_ptr(), + newValues.data_ptr(), + nnz, + newNnz, + stride + ); }); } diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 81058ec266f2..d0aafe680efb 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -340,13 +340,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> - <<>>( - TensorCAddOp(value.to()), - V_INFO(r), I_INFO(indices), V_INFO(values), - static_cast(nnz)); - }); + apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> + <<>>( + TensorCAddOp(value.to()), + V_INFO(r), I_INFO(indices), V_INFO(values), + static_cast(nnz)); }); } else { TORCH_CHECK(cuda::getApplyGrid(nnz * block.x, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); @@ -356,13 +354,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - apply::sparseElementwiseKernel, uint64_t, scalar_t> - <<>>( - TensorCAddOp(value.to()), - V_INFO(r), I_INFO(indices), V_INFO(values), - static_cast(nnz)); - }); + apply::sparseElementwiseKernel, uint64_t, scalar_t> + <<>>( + TensorCAddOp(value.to()), + V_INFO(r), I_INFO(indices), V_INFO(values), + static_cast(nnz)); }); } } else { @@ -373,11 +369,9 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT // NB: Purposely not inplace! AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - if (value.to() != static_cast(1)) { - values = values.mul(value); - } - }); + if (value.to() != static_cast(1)) { + values = values.mul(value); + } }); int64_t view_rows = 1; @@ -445,11 +439,9 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_sparse_cuda", [&] { - if (value.to() != static_cast(1)) { - s_values_ = s_values_.mul(value); - } - }); + if (value.to() != static_cast(1)) { + s_values_ = s_values_.mul(value); + } }); LongTensor r_indices_ = at::cat({t_indices_, s_indices_}, 1); Tensor r_values_ = at::cat({t_values_, s_values_}, 0); diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index a77b1935eda6..5af2c14b80cb 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -119,17 +119,40 @@ vTensor pack_weights( } // shader KO4C4HW_to_image - float image[4 * C_4][OC_4][KH * KW][4]; - memset(image, 0.f, 16 * C_4 * OC_4 * KH * KW * sizeof(float)); + struct Image3D { + float* data_; + uint32_t dim0_, dim1_, dim2_; + + Image3D(uint32_t dim0, uint32_t dim1, uint32_t dim2) { + dim0_ = dim0; + dim1_ = dim1; + dim2_ = dim2; + data_ = new float[dim0 * dim1 * dim2 * 4]; + memset(data_, 0.f, dim0 * dim1 * dim2 * 4 * sizeof(float)); + } + + inline uint32_t idx(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) { + return i3 + i2 * 4 + i1 * 4 * dim2_ + i0 * 4 * dim2_ * dim1_; + } + + void set(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, float value) { + data_[idx(i0, i1, i2, i3)] = value; + } + + float get(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) { + return data_[idx(i0, i1, i2, i3)]; + } + } image{4 * C_4, OC_4, KH * KW}; + for (uint32_t sx = 0; sx < C_4; ++sx) { for (uint32_t sy = 0; sy < OC_4; ++sy) { for (uint32_t sz = 0; sz < (KH * KW); ++sz) { for (uint32_t vi = 0; vi < 4; ++vi) { int bufferVIdx = 4 * sx * KH * KW + 4 * sy * C_4 * KH * KW + 4 * sz; - image[4 * sx + 0][sy][sz][vi] = dst[4 * (bufferVIdx + 0) + vi]; - image[4 * sx + 1][sy][sz][vi] = dst[4 * (bufferVIdx + 1) + vi]; - image[4 * sx + 2][sy][sz][vi] = dst[4 * (bufferVIdx + 2) + vi]; - image[4 * sx + 3][sy][sz][vi] = dst[4 * (bufferVIdx + 3) + vi]; + image.set(4 * sx + 0, sy, sz, vi, dst[4 * (bufferVIdx + 0) + vi]); + image.set(4 * sx + 1, sy, sz, vi, dst[4 * (bufferVIdx + 1) + vi]); + image.set(4 * sx + 2, sy, sz, vi, dst[4 * (bufferVIdx + 2) + vi]); + image.set(4 * sx + 3, sy, sz, vi, dst[4 * (bufferVIdx + 3) + vi]); } } } @@ -143,7 +166,7 @@ vTensor pack_weights( for (uint32_t sy = 0; sy < H; ++sy) { for (uint32_t sz = 0; sz < D; ++sz) { for (uint32_t szvi = 0; szvi < 4; ++szvi) { - dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image[sx][sy][sz][szvi]; + dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image.get(sx, sy, sz, szvi); } } } diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 102931fd4aa7..d1b0acb87c28 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -30,8 +30,6 @@ std::atomic defaultNodeId(-1); std::atomic next_thread_id_ {0}; thread_local uint64_t current_thread_id_ = 0; -thread_local bool tls_record_function_enabled_ = true; - // Low probability constant static const double kLowProb = 0.001; struct CoinflipTLS { @@ -68,6 +66,10 @@ void set_record_function_tls_(const RecordFunctionTLS& tls) { class CallbackManager { public: CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) { + if (cb.samplingProb() > kLowProb) { + // pre-sampling of RecordFunction with prob. kLowProb cannot be used + at::bumpRecordAllFunctions(); + } // note: monotonically increasing callbacks_unique_id keeps // sorted_tls_callbacks_ sorted auto handle = next_unique_callback_handle(); @@ -76,6 +78,10 @@ class CallbackManager { } CallbackHandle addGlobalCallback(RecordFunctionCallback cb) { + if (cb.samplingProb() > kLowProb) { + // pre-sampling of RecordFunction with prob. kLowProb cannot be used + at::bumpRecordAllFunctions(); + } auto handle = next_unique_callback_handle(); sorted_global_callbacks_.emplace_back(std::move(cb), handle); return handle; @@ -92,6 +98,10 @@ class CallbackManager { return el.second == handle; }); if (it != cbs.end()) { + if (it->first.samplingProb() > kLowProb) { + // try to restore pre-sampling of RecordFunction + at::releaseRecordAllFunctions(); + } // keeps it sorted cbs.erase(it); return true; @@ -127,7 +137,13 @@ class CallbackManager { // callbackShouldRun is even hotter because it's called multiple // times per init(). Profiling shows that the function prologue is // taking up a significant fraction of the time. - static bool C10_ALWAYS_INLINE callbackShouldRun(const RecordFunctionCallback& cb, RecordScope scope) { + static bool C10_ALWAYS_INLINE callbackShouldRun( + const RecordFunctionCallback& cb, RecordScope scope, bool pre_sampled) { + TORCH_INTERNAL_ASSERT( + !pre_sampled || (cb.sampling_prob_ <= kLowProb), + "Incorrect usage of a pre-sampled RecordFunction with a high-frequency " + " or non-sampled callback"); + // first check whether this callback is interested in // the given scope type if (!cb.checkScope(scope)) { @@ -138,36 +154,45 @@ class CallbackManager { return cb.should_run_(cb); } - if (cb.sampling_prob_ == 1.0) { - return true; + // otherwise potentially do the sampling + double sampling_prob = cb.sampling_prob_; + if (pre_sampled) { + // adjust the sampling rate to account for kLowProb pre-sampling of + // the RecordFunction + sampling_prob /= kLowProb; } - // model the low probability events as events happening - // with probability kLowProb followed by another sampling with - // probability (sampling_prob__ / kLowProb), then replace the coin - // flip for kLowProb with a thread local number of tries tries_left_ - // sampled from the geometric distribution. - if (cb.sampling_prob_ < kLowProb) { - if (coinflip_tls_.tries_left_ == 0) { - coinflip_tls_.tries_left_ = sample_geometric(); - return (sample_zero_one() < cb.sampling_prob_ / kLowProb); + + if (sampling_prob < 1.0) { + // model the low probability events as events happening + // with probability kLowProb followed by another sampling with + // probability (sampling_prob / kLowProb), then replace the coin + // flip for kLowProb with a thread local number of tries tries_left_ + // sampled from the geometric distribution. + if (sampling_prob < kLowProb) { + if (coinflip_tls_.tries_left_ == 0) { + coinflip_tls_.tries_left_ = sample_geometric(); + return (sample_zero_one() < sampling_prob / kLowProb); + } else { + --coinflip_tls_.tries_left_; + return false; + } } else { - --coinflip_tls_.tries_left_; - return false; + return (sample_zero_one() < sampling_prob); } - } else { - return (sample_zero_one() < cb.sampling_prob_); } + + return true; } // init is called by RecordFunction in constructor to // determine which thread local and global callbacks are going // to be executed and whether any of them need inputs - inline void init(RecordFunction& rec_fn, RecordScope scope) { + inline void init(RecordFunction& rec_fn, RecordScope scope, bool pre_sampled) { bool found_needs_inputs = false; bool found_needs_ids = false; for (const auto& cb: rf_tls_.sorted_tls_callbacks_) { - if (callbackShouldRun(cb.first, scope)) { + if (callbackShouldRun(cb.first, scope, pre_sampled)) { if (cb.first.needsInputs()) { found_needs_inputs = true; } @@ -182,7 +207,7 @@ class CallbackManager { } for (const auto& cb: sorted_global_callbacks_) { - if (callbackShouldRun(cb.first, scope)) { + if (callbackShouldRun(cb.first, scope, pre_sampled)) { if (cb.first.needsInputs()) { found_needs_inputs = true; } @@ -308,7 +333,6 @@ namespace { } } // namespace - RecordFunctionCallbacks _getTLSCallbacks() { return rf_tls_.sorted_tls_callbacks_; } @@ -374,12 +398,12 @@ void enableRecordFunction(bool enable) { rf_tls_.tls_record_function_enabled_ = enable; } -RecordFunction::RecordFunction(RecordScope scope) { +RecordFunction::RecordFunction(RecordScope scope, bool pre_sampled) { auto* rf_tls_ptr = &rf_tls_; if (rf_tls_ptr->tls_record_function_enabled_) { auto& m = manager(); if (!m.sorted_global_callbacks_.empty() || !rf_tls_ptr->sorted_tls_callbacks_.empty()) { - m.init(*this, scope); + m.init(*this, scope, pre_sampled); } } } @@ -451,4 +475,49 @@ void RecordFunction::end() { } } +// RecordFunction pre-sampling +namespace { +// Whether to try to create RecordFunction on each call (>0) or +// use pre-sampling (=0) +std::atomic global_record_all_functions_ {0}; +} + +void bumpRecordAllFunctions() { + global_record_all_functions_.fetch_add(1, std::memory_order_relaxed); +} + +void releaseRecordAllFunctions() { + TORCH_CHECK(global_record_all_functions_.fetch_sub(1, std::memory_order_relaxed) >= 0); +} + +bool checkRecordAllFunctions() { + return (global_record_all_functions_.load(std::memory_order_relaxed) > 0); +} + +bool shouldRunRecordFunction(bool* pre_sampled) { + auto* rf_tls_ptr = &rf_tls_; + if (rf_tls_ptr->sorted_tls_callbacks_.empty() && !manager().hasGlobalCallbacks()) { + *pre_sampled = false; + return false; + } + if (global_record_all_functions_.load(std::memory_order_relaxed) > 0) { + *pre_sampled = false; + return true; + } + if (!rf_tls_ptr->tls_record_function_enabled_) { + *pre_sampled = false; + return false; + } + + *pre_sampled = true; + auto* coinflip_tls_ptr = &coinflip_tls_; + if (coinflip_tls_ptr->tries_left_ == 0) { + coinflip_tls_ptr->tries_left_ = sample_geometric(); + return true; + } else { + --coinflip_tls_ptr->tries_left_; + return false; + } +} + } // namespace at diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 4b07d13aa747..6b2e08576068 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -23,6 +23,8 @@ enum class C10_API_ENUM RecordScope : uint8_t { BACKWARD_FUNCTION, // TorchScript functions, methods TORCHSCRIPT_FUNCTION, + // Kernel Function dtype Tag + KERNEL_FUNCTION_DTYPE, // User defined scope (e.g. with record_function()) USER_SCOPE, NUM_SCOPES, // must be the last in the list @@ -90,8 +92,11 @@ typedef uint64_t RecordFunctionHandle; struct TORCH_API RecordFunction { // Default constructor is used with before function called afterwards: // scope - record scope that this function tracks + // pre_sampled - whether this RecordFunction was already pre-sampled with + // kLowProb probability RecordFunction( - RecordScope scope = RecordScope::FUNCTION); + RecordScope scope = RecordScope::FUNCTION, + bool pre_sampled = false); template void before( @@ -238,6 +243,9 @@ struct TORCH_API RecordFunction { // flag is used to check whether the start callbacks were called bool called_start_callbacks_ = false; + // Whether the RecordFunction is pre-sampled + bool pre_sampled_ = false; + // Used internally to keep track of thread local and global callbacks // that were picked to run; must be sorted; CallbackHandles sorted_active_tls_handles_; @@ -308,17 +316,6 @@ class TORCH_API RecordFunctionCallback { scopes_.fill(true); } - // This interface is for observers that do not pass an ObserverContext object - // between start and end callbacks. - explicit RecordFunctionCallback( - std::function start, - std::function end = - [](const RecordFunction&) {}): - start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }}, - end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} { - scopes_.fill(true); - } - RecordFunctionCallback& needsInputs(bool needs_inputs) { needs_inputs_ = needs_inputs; return *this; @@ -330,7 +327,7 @@ class TORCH_API RecordFunctionCallback { } RecordFunctionCallback& samplingProb(double sampling_prob) { - TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob_ <= 1.0, + TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0, "Invalid sampling probability"); sampling_prob_ = sampling_prob; return *this; @@ -544,10 +541,27 @@ struct TORCH_API RecordFunctionTLS { RecordFunctionCallbacks sorted_tls_callbacks_; bool tls_record_function_enabled_ = true; + + // Stores the number of coin flips before the next successful coin flip + int tries_left_ = 0; }; TORCH_API const RecordFunctionTLS& get_record_function_tls_(); TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls); +// Checks whether RecordFunction should be called, +// sets boolean pointed by the argument to whether pre-sampling was used +TORCH_API bool shouldRunRecordFunction(bool*); + +// The following functions are used to disable/enable pre-sampling of RecordFunction +// when high-frequency/non-sampled callbacks are added/removed. +// Note: every call to bumpRecordAllFunctions() is supposed to be matched with +// the corresponding releaseRecordAllFunctions() call. +// Note: disabling pre-sampling of RecordFunction incurs an extra overhead, since +// RecordFunction will be created for each operator call. +TORCH_API void bumpRecordAllFunctions(); +TORCH_API void releaseRecordAllFunctions(); +TORCH_API bool checkRecordAllFunctions(); + } // namespace at diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h index 50623dc2dfed..8f5e35d3ea73 100644 --- a/aten/src/ATen/templates/Functions.h +++ b/aten/src/ATen/templates/Functions.h @@ -134,4 +134,12 @@ inline int64_t numel(const Tensor& tensor) { return tensor.numel(); } +inline int64_t size(const Tensor& tensor, int64_t dim) { + return tensor.size(dim); +} + +inline int64_t stride(const Tensor& tensor, int64_t dim) { + return tensor.stride(dim); +} + } diff --git a/aten/src/ATen/templates/MetaFunctions.h b/aten/src/ATen/templates/MetaFunctions.h index d0489d1964f3..7ad20b734330 100644 --- a/aten/src/ATen/templates/MetaFunctions.h +++ b/aten/src/ATen/templates/MetaFunctions.h @@ -6,6 +6,7 @@ #include namespace at { + namespace meta { ${declarations} diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index 31005b04308f..e923f6d73bd0 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -20,6 +20,9 @@ #include #include #include +#include +#include +#include #include #include diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index e79957e5ca6e..75f614bb6ea8 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -115,6 +116,26 @@ class CAFFE2_API Tensor { return impl_->storage_offset(); } + Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + if (is_contiguous(memory_format)) { + return *this; + } else { + return __dispatch_contiguous(memory_format); + } + } + + int64_t size(int64_t dim) const { + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + dim = c10::maybe_wrap_dim(dim, this->dim(), false); + return sizes()[dim]; + } + + int64_t stride(int64_t dim) const { + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + dim = c10::maybe_wrap_dim(dim, this->dim(), false); + return strides()[dim]; + } + TensorImpl * unsafeGetTensorImpl() const { return impl_.get(); } diff --git a/aten/src/ATen/test/cuda_atomic_ops_test.cu b/aten/src/ATen/test/cuda_atomic_ops_test.cu index 285623349e52..920a72452916 100644 --- a/aten/src/ATen/test/cuda_atomic_ops_test.cu +++ b/aten/src/ATen/test/cuda_atomic_ops_test.cu @@ -11,7 +11,7 @@ template __global__ void addition_test_kernel(T * a, T * sum) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int idx = (tid) % arraysize; - + gpuAtomicAdd(&sum[idx], a[idx]); } @@ -19,7 +19,7 @@ template __global__ void mul_test_kernel(T * a, T * sum) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int idx = (tid) % arraysize; - + gpuAtomicMul(&sum[idx], a[idx]); } @@ -29,7 +29,7 @@ void test_atomic_add() { dim3 dimGrid(1, 1); T *a, *sum, *answer, *ad, *sumd; - + a = (T*)malloc(arraysize * sizeof(T)); sum = (T*)malloc(arraysize * sizeof(T)); answer = (T*)malloc(arraysize * sizeof(T)); @@ -42,7 +42,7 @@ void test_atomic_add() { cudaMalloc((void**)&ad, arraysize * sizeof(T)); cudaMalloc((void**)&sumd, arraysize * sizeof(T)); - + cudaMemcpy(ad, a, arraysize * sizeof(T), cudaMemcpyHostToDevice); cudaMemcpy(sumd, sum, arraysize * sizeof(T), cudaMemcpyHostToDevice); @@ -67,7 +67,7 @@ void test_atomic_mul() { dim3 dimGrid(1, 1); T *a, *sum, *answer, *ad, *sumd; - + a = (T*)malloc(arraysize * sizeof(T)); sum = (T*)malloc(arraysize * sizeof(T)); answer = (T*)malloc(arraysize * sizeof(T)); @@ -75,12 +75,12 @@ void test_atomic_mul() { for (int i = 0; i < arraysize; ++i) { a[i] = 2; sum[i] = 2; - answer[i] = pow(sum[i], factor); + answer[i] = pow(sum[i], static_cast(factor)); } cudaMalloc((void**)&ad, arraysize * sizeof(T)); cudaMalloc((void**)&sumd, arraysize * sizeof(T)); - + cudaMemcpy(ad, a, arraysize * sizeof(T), cudaMemcpyHostToDevice); cudaMemcpy(sumd, sum, arraysize * sizeof(T), cudaMemcpyHostToDevice); @@ -105,7 +105,7 @@ TEST(TestAtomicOps, TestAtomicAdd) { test_atomic_add(); test_atomic_add(); test_atomic_add(); - + test_atomic_add(); test_atomic_add(); test_atomic_add(); diff --git a/aten/src/ATen/test/vulkan_test.cpp b/aten/src/ATen/test/vulkan_test.cpp index 7c4e96f7f1a6..6b066b4337be 100644 --- a/aten/src/ATen/test/vulkan_test.cpp +++ b/aten/src/ATen/test/vulkan_test.cpp @@ -927,10 +927,10 @@ TEST(VulkanTest, avg_pool2d) { auto t_in = at::rand({1, 3, 7, 7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto t_out_expected = at::avg_pool2d(t_in, {2, 2}, {1}, {0}, {1}); + auto t_out_expected = at::avg_pool2d(t_in, {2, 2}, {1}, {0}, true); auto tv_in = t_in.vulkan(); - auto tv_out = at::avg_pool2d(tv_in, {2, 2}, {1}, {0}, {1}); + auto tv_out = at::avg_pool2d(tv_in, {2, 2}, {1}, {0}, true); auto t_out = tv_out.cpu(); const auto check = almostEqual(t_out, t_out_expected); diff --git a/aten/src/TH/THAllocator.cpp b/aten/src/TH/THAllocator.cpp index 55b42d2f9d27..53b67a17032f 100644 --- a/aten/src/TH/THAllocator.cpp +++ b/aten/src/TH/THAllocator.cpp @@ -6,6 +6,7 @@ #endif #include +#include /* stuff for mapped files */ #ifdef _WIN32 @@ -74,24 +75,26 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, #ifdef _WIN32 if (flags_ & TH_ALLOCATOR_MAPPED_SHAREDMEM) { // Shadowing - const char *filename; - const char *eventname; + const wchar_t *filename; + const wchar_t *eventname; + const std::wstring wFilename = c10::u8u16(filename_); + const std::wstring wEventname = c10::u8u16(eventname_); LARGE_INTEGER hfilesz; if (filename_[0] == '/') { - filename = filename_.c_str() + 1; - eventname = eventname_.c_str() + 1; + filename = wFilename.c_str() + 1; + eventname = wEventname.c_str() + 1; } else { - filename = filename_.c_str(); - eventname = eventname_.c_str(); + filename = wFilename.c_str(); + eventname = wEventname.c_str(); } hfilesz.QuadPart = size; if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) { - event_ = CreateEvent(nullptr, FALSE, FALSE, eventname); + event_ = CreateEventW(nullptr, FALSE, FALSE, eventname); } else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) { - event_ = OpenEvent(EVENT_ALL_ACCESS, FALSE, eventname); + event_ = OpenEventW(EVENT_ALL_ACCESS, FALSE, eventname); } else { AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE"); } @@ -101,9 +104,9 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, } if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) { - handle_ = CreateFileMapping(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename); + handle_ = CreateFileMappingW(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename); } else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) { - handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, filename); + handle_ = OpenFileMappingW(FILE_MAP_ALL_ACCESS, FALSE, filename); } else { AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE"); } @@ -136,15 +139,21 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, AT_ERROR("TH_ALLOCATOR_MAPPED_FROMFD not supported on Windows"); } + // Shadowing + const wchar_t *filename; + const std::wstring wFilename = c10::u8u16(filename_); + + filename = wFilename.c_str(); + /* open file */ /* FILE_FLAG_RANDOM_ACCESS ? */ if (flags_) { - hfile = CreateFileA(filename_.c_str(), GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0); + hfile = CreateFileW(filename, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0); if (hfile == INVALID_HANDLE_VALUE) { AT_ERROR("could not open file <", filename_, "> in read-write mode; error code: <", GetLastError(), ">"); } } else { - hfile = CreateFileA(filename_.c_str(), GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + hfile = CreateFileW(filename, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); if (hfile == INVALID_HANDLE_VALUE) { AT_ERROR("could not open file <", filename_, "> in read-only mode; error code: <", GetLastError(), ">"); } @@ -181,11 +190,11 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, /* get map handle */ if (flags_) { - if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { + if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">"); } } else { - if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { + if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">"); } } diff --git a/aten/src/TH/generic/THLapack.cpp b/aten/src/TH/generic/THLapack.cpp index c0fb51f53e45..5b4ef15e7c2c 100644 --- a/aten/src/TH/generic/THLapack.cpp +++ b/aten/src/TH/generic/THLapack.cpp @@ -4,8 +4,6 @@ TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info); -TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); -TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info); TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info); TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info); @@ -31,21 +29,6 @@ void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, s #endif } -/* Compute for an N-by-N real nonsymmetric matrix A, the eigenvalues and, -optionally, the left and/or right eigenvectors */ -void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info) -{ -#ifdef USE_LAPACK -#if defined(TH_REAL_IS_DOUBLE) - dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); -#else - sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); -#endif -#else - THError("geev : Lapack library not found in compile time\n"); -#endif -} - /* Cholesky factorization based Matrix Inverse */ void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info) { diff --git a/aten/src/TH/generic/THLapack.h b/aten/src/TH/generic/THLapack.h index 287915c74d26..121eee871c67 100644 --- a/aten/src/TH/generic/THLapack.h +++ b/aten/src/TH/generic/THLapack.h @@ -4,8 +4,6 @@ /* ||AX-B|| */ TH_API void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, scalar_t *work, int lwork, int *info); -/* Non-sym eigenvals */ -TH_API void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); /* Positive Definite matrices */ /* Matrix inverse based on Cholesky factorization */ diff --git a/aten/src/TH/generic/THStorage.cpp b/aten/src/TH/generic/THStorage.cpp index 2db795719557..a085f31c740f 100644 --- a/aten/src/TH/generic/THStorage.cpp +++ b/aten/src/TH/generic/THStorage.cpp @@ -115,10 +115,9 @@ void THStorage_(resizeBytes)(THStorage* storage, ptrdiff_t size_bytes) { void THStorage_(fill)(THStorage *storage, scalar_t value) { - ptrdiff_t i; auto type_meta = caffe2::TypeMeta::Make(); size_t numel = storage->nbytes() / type_meta.itemsize(); - for (i = 0; i < numel; i++) + for (size_t i = 0; i < numel; i++) THStorage_(data)(storage)[i] = value; } diff --git a/aten/src/TH/generic/THStorageCopy.cpp b/aten/src/TH/generic/THStorageCopy.cpp index dc19deea7652..2d6ec8a05eb6 100644 --- a/aten/src/TH/generic/THStorageCopy.cpp +++ b/aten/src/TH/generic/THStorageCopy.cpp @@ -8,7 +8,7 @@ void THStorage_(copy)(THStorage *storage, THStorage *src) scalar_t *scalar_src = THStorage_(data)(src); scalar_t *data = THStorage_(data)(storage); uint64_t numel = storage->nbytes() / sizeof(scalar_t); - for (ptrdiff_t i = 0; i < numel; ++i) { + for (uint64_t i = 0; i < numel; ++i) { data[i] = scalar_src[i]; } } @@ -19,11 +19,10 @@ void THStorage_(copy)(THStorage *storage, THStorage *src) #define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \ void THStorage_(copy##TYPENAMESRC)( \ THStorage * storage, TH##TYPENAMESRC##Storage * src) { \ - ptrdiff_t i; \ auto data = THStorage_(data)(storage); \ auto src_data = TH##TYPENAMESRC##Storage_data(src); \ uint64_t numel = storage->nbytes() / sizeof(scalar_t); \ - for (i = 0; i < numel; i++) \ + for (uint64_t i = 0; i < numel; i++) \ data[i] = static_cast(src_data[i]); \ } diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index 6a79f3e14c14..9c1eb3cdfe22 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -5,6 +5,7 @@ #include #include #include +#include // Finds non-zero elements of a tensor and returns their subscripts void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) @@ -254,6 +255,13 @@ void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar numel = THLongTensor_nElement(index); THArgCheck(THTensor_nDimensionLegacyNoScalars(index) == 1, 3, "Index is supposed to be a vector"); THArgCheck(dim < THTensor_nDimensionLegacyNoScalars(tensor), 4,"Indexing dim %d is out of bounds of tensor", dim); + at::assert_no_overlap(tensor, index); + if (at::has_internal_overlap(tensor) == at::MemOverlap::YES) { + TORCH_WARN( + "Use of index_fill_ on expanded tensors is deprecated. " + "Please clone() the tensor before performing this operation. " + "This also applies to advanced indexing e.g. tensor[mask] = scalar"); + } index = THLongTensor_newContiguous(index); index_data = THLongTensor_data(index); diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index 76d7d7bc48d8..e6c200169191 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -191,88 +191,6 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) if (free_b) c10::raw::intrusive_ptr::decref(b); } -void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, bool eigenvectors) -{ - char jobvr = eigenvectors ? 'V' : 'N'; - int n, lda, lwork, info, ldvr; - THTensor *work=nullptr, *wi, *wr, *a; - scalar_t wkopt; - scalar_t *rv_data; - int64_t i; - - THTensor *re__ = NULL; - THTensor *rv__ = NULL; - - THArgCheck(a_->dim() == 2, 1, "A should be 2 dimensional"); - THArgCheck(a_->size(0) == a_->size(1), 1,"A should be square"); - THArgCheck(THTensor_(isFinite)(a_), 1, "A should not contain infs or NaNs"); - - /* we want to definitely clone a_ for geev*/ - a = THTensor_(cloneColumnMajor)(NULL, a_); - - n = a->size(0); - lda = n; - - wi = THTensor_(newWithSize1d)(n); - wr = THTensor_(newWithSize1d)(n); - - rv_data = NULL; - ldvr = 1; - if (jobvr == 'V') - { - THTensor_(resize2d)(rv_,n,n); - /* guard against someone passing a correct size, but wrong stride */ - rv__ = THTensor_(newTransposedContiguous)(rv_); - rv_data = rv__->data(); - ldvr = n; - } - THTensor_(resize2d)(re_,n,2); - re__ = THTensor_(newContiguous)(re_); - - if (n > 0) { // lapack doesn't work with size 0 - /* get optimal workspace size */ - THLapack_(geev)('N', jobvr, n, a->data(), lda, wr->data(), wi->data(), - NULL, 1, rv_data, ldvr, &wkopt, -1, &info); - - lwork = (int)wkopt; - work = THTensor_(newWithSize1d)(lwork); - - THLapack_(geev)('N', jobvr, n, a->data(), lda, wr->data(), wi->data(), - NULL, 1, rv_data, ldvr, work->data(), lwork, &info); - - THLapackCheckWithCleanup(" Lapack Error in %s : %d off-diagonal elements of an didn't converge to zero", - THCleanup(c10::raw::intrusive_ptr::decref(re__); - c10::raw::intrusive_ptr::decref(rv__); - c10::raw::intrusive_ptr::decref(a); - c10::raw::intrusive_ptr::decref(wi); - c10::raw::intrusive_ptr::decref(wr); - c10::raw::intrusive_ptr::decref(work);), - "geev", info,""); - } - - { - scalar_t *re_data = re__->data(); - scalar_t *wi_data = wi->data(); - scalar_t *wr_data = wr->data(); - for (i=0; i #include #include +#include // // This file contains pointwise operation functions and kernels that @@ -242,14 +243,11 @@ bool THC_pointwiseApply1(THCState* state, // (or vice versa), the contiguous tensor can be collapsed to one // dimension, and the loop to translate the linear index to the array // index can be similarly collapsed. That is what this unrolling is for. -#define HANDLE_CASE(TYPE, A) \ - kernelPointwiseApply1 \ - <<>>( \ - OffsetInfo \ - (aInfo), \ - (TYPE) totalElements, op); +#define HANDLE_CASE(TYPE, A) \ + kernelPointwiseApply1 \ + <<>>( \ + OffsetInfo(aInfo), (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define HANDLE_A_CASE(TYPE, A) { \ switch (A) { \ @@ -298,6 +296,7 @@ bool THC_pointwiseApply1(THCState* state, uint64_t, 1> <<>>( aOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 @@ -310,6 +309,7 @@ bool THC_pointwiseApply1(THCState* state, uint64_t, -1> <<>>( aOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE @@ -392,16 +392,13 @@ bool THC_pointwiseApply2(THCState* state, // dimension, and the loop to translate the linear index to the array // index can be similarly collapsed. That is what this unrolling is for. #define HANDLE_CASE(TYPE, A, B) \ - kernelPointwiseApply2 \ + kernelPointwiseApply2 \ <<>>( \ - OffsetInfo \ - (aInfo), \ - OffsetInfo \ - (bInfo), \ - (TYPE) totalElements, op); + OffsetInfo(aInfo), \ + OffsetInfo(bInfo), \ + (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + #define HANDLE_B_CASE(TYPE, A, B) { \ switch (B) { \ @@ -474,6 +471,7 @@ bool THC_pointwiseApply2(THCState* state, uint64_t, 1, 1> <<>>( aOffset, bOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); @@ -488,6 +486,7 @@ bool THC_pointwiseApply2(THCState* state, uint64_t, -1, -1> <<>>( aOffset, bOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE @@ -598,7 +597,8 @@ bool THC_pointwiseApply3(THCState* state, (bInfo), \ OffsetInfo \ (cInfo), \ - (TYPE) totalElements, op); + (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define HANDLE_C_CASE(TYPE, A, B, C) { \ switch (C) { \ @@ -697,6 +697,7 @@ bool THC_pointwiseApply3(THCState* state, uint64_t, 1, 1, 1> <<>>( aOffset, bOffset, cOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); @@ -715,6 +716,7 @@ bool THC_pointwiseApply3(THCState* state, uint64_t, -1, -1, -1> <<>>( aOffset, bOffset, cOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE diff --git a/aten/src/THC/THCAsmUtils.cuh b/aten/src/THC/THCAsmUtils.cuh index 6375891bd7f2..be0bf6ffa1ba 100644 --- a/aten/src/THC/THCAsmUtils.cuh +++ b/aten/src/THC/THCAsmUtils.cuh @@ -94,15 +94,16 @@ __device__ __forceinline__ int getLaneId() { #if defined(__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskLt() { - std::uint64_t m = (1ull << getLaneId()) - 1ull; + const std::uint64_t m = (1ull << getLaneId()) - 1ull; return m; +} #else __device__ __forceinline__ unsigned getLaneMaskLt() { unsigned mask; asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); return mask; -#endif } +#endif #if defined (__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskLe() { @@ -119,27 +120,28 @@ __device__ __forceinline__ unsigned getLaneMaskLe() { #if defined(__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskGt() { - std::uint64_t m = getLaneMaskLe(); + const std::uint64_t m = getLaneMaskLe(); return m ? ~m : m; +} #else __device__ __forceinline__ unsigned getLaneMaskGt() { unsigned mask; asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); return mask; -#endif } +#endif #if defined(__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskGe() { - std::uint64_t m = getLaneMaskLt(); + const std::uint64_t m = getLaneMaskLt(); return ~m; +} #else __device__ __forceinline__ unsigned getLaneMaskGe() { unsigned mask; asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); return mask; -#endif } - +#endif #endif // THC_ASM_UTILS_INC diff --git a/aten/src/THC/THCReduceAll.cuh b/aten/src/THC/THCReduceAll.cuh index 9546f85f61c9..af2e264e6528 100644 --- a/aten/src/THC/THCReduceAll.cuh +++ b/aten/src/THC/THCReduceAll.cuh @@ -10,6 +10,7 @@ // #include +#include #include #ifdef __HIP_PLATFORM_HCC__ @@ -209,6 +210,7 @@ void callReduceAll(THCState* state, <<>>( in, (IndexType) totalElements, init, modifyOp, reduceOp, (AccT*) scratchSpace); + C10_CUDA_KERNEL_LAUNCH_CHECK(); int numPass1Blocks = grid.x; getPass2ReduceBlockGrid(state, totalElements, grid, block); @@ -218,6 +220,7 @@ void callReduceAll(THCState* state, <<>>( numPass1Blocks, init, reduceOp, (AccT*) scratchSpace, devOut); + C10_CUDA_KERNEL_LAUNCH_CHECK(); THCudaFree(state, scratchSpace); } else { @@ -227,6 +230,7 @@ void callReduceAll(THCState* state, kernelReduceAll <<>>( in, (IndexType) totalElements, init, modifyOp, reduceOp, devOut); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } diff --git a/aten/src/THC/THCTensorSort.cu b/aten/src/THC/THCTensorSort.cu index 8969209a1bdc..189e73b909fb 100644 --- a/aten/src/THC/THCTensorSort.cu +++ b/aten/src/THC/THCTensorSort.cu @@ -1,5 +1,6 @@ #include #include +#include void THCudaLongTensor_fillSliceWithIndex(THCState* state, THCudaLongTensor* t, @@ -28,8 +29,10 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state, #define FILL_INDEX(T, DIM) \ fillSliceWithIndex \ - <<>>( \ - info, numSlices, sliceSize, info.strides[collapseDim]) + <<>>( \ + info, numSlices, sliceSize, info.strides[collapseDim]); \ + C10_CUDA_KERNEL_LAUNCH_CHECK() + if (THCTensor_canUse32BitIndexMath(state, t)) { TensorInfo info = @@ -59,6 +62,5 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state, } #undef FILL_INDEX - THCudaCheck(cudaGetLastError()); } } diff --git a/aten/src/THC/generic/THCTensorIndex.cu b/aten/src/THC/generic/THCTensorIndex.cu index 07303fa47096..3f506d345714 100644 --- a/aten/src/THC/generic/THCTensorIndex.cu +++ b/aten/src/THC/generic/THCTensorIndex.cu @@ -3,6 +3,8 @@ #else #include +#include +#include // Check tensor dimensions for index operations, and return the slice size. // src can be nullptr in case of indexFill: in that case it is ignored. @@ -126,11 +128,12 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ - indexCopySmallIndex \ - <<>>( \ - dstInfo, srcInfo, indicesInfo, \ - dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize); +#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ + indexCopySmallIndex \ + <<>>( \ + dstInfo, srcInfo, indicesInfo, \ + dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, TYPE, \ DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ @@ -140,7 +143,8 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT dstInfo, srcInfo, indicesInfo, \ dstCopyDim, srcCopyDim, srcTotalSize, \ (IDX_IS_MAJOR) ? sliceSize : numIndices, \ - dstCopyDimSize); + dstCopyDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); @@ -279,6 +283,13 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING); + at::assert_no_overlap(dst, indices); + if (at::has_internal_overlap(dst) == at::MemOverlap::YES) { + TORCH_WARN( + "Use of index_fill_ on expanded tensors is deprecated. " + "Please clone() the tensor before performing this operation. " + "This also applies to advanced indexing e.g. tensor[mask] = scalar"); + } // The `src` is partitioned into two parts: // -the size of each slice we are indexing, which is the @@ -299,11 +310,12 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \ +#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \ indexFillSmallIndex \ - <<>>( \ - dstInfo, indicesInfo, \ - dstFillDim, sliceSize, dstFillDimSize, val); + <<>>( \ + dstInfo, indicesInfo, \ + dstFillDim, sliceSize, dstFillDimSize, val); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM, IDX_IS_MAJOR) \ indexFillLargeIndex \ @@ -311,7 +323,8 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT dstInfo, indicesInfo, \ dstFillDim, sliceSize * numIndices, \ (IDX_IS_MAJOR) ? sliceSize : numIndices, \ - dstFillDimSize, val); + dstFillDimSize, val); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu index 8c0dac0aa686..216a96443887 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.cu +++ b/aten/src/THC/generic/THCTensorMathMagma.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.cu" #else +#include + #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) #ifdef USE_MAGMA @@ -171,8 +173,10 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper dim3 threads(128); if (uplo == 'U') { THCTensor_(copyUpperSymmetric)<<>>(input_data, n, len); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { THCTensor_(copyLowerSymmetric)<<>>(input_data, n, len); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } THCTensor_(freeCopyTo)(state, input, ra_); diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h index ae46a62c9ec6..6fc51f393135 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.h +++ b/aten/src/THC/generic/THCTensorMathMagma.h @@ -6,7 +6,6 @@ // MAGMA (i.e. CUDA implementation of LAPACK functions) THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); -THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, bool eigenvectors); THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper); THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_); diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu index 76f470ce7dfb..ce2f124215ca 100644 --- a/aten/src/THC/generic/THCTensorMathReduce.cu +++ b/aten/src/THC/generic/THCTensorMathReduce.cu @@ -41,9 +41,11 @@ void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar dim3 threads(32); THCTensor_kernel_renorm - <<>> - (THCTensor_(data)(state, data), scalar_cast(value), size, scalar_cast(maxnorm)); + <<>>(THCTensor_(data)(state, data), + scalar_cast(value), size, scalar_cast(maxnorm)); + // Do not replace with C10_CUDA_KERNEL_LAUNCH_CHECK() yet as it exhibits different behaviour from THError(). + // THError() calls the an error handler, or throws std::runtime_error if a custom handler hasn't been registered. cudaError_t errcode = cudaGetLastError(); if(errcode != cudaSuccess) THError(cudaGetErrorString(errcode)); diff --git a/aten/src/THC/generic/THCTensorMode.cu b/aten/src/THC/generic/THCTensorMode.cu index 9fe955f3cf8d..8c428c9a5d1b 100644 --- a/aten/src/THC/generic/THCTensorMode.cu +++ b/aten/src/THC/generic/THCTensorMode.cu @@ -2,6 +2,7 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMode.cu" #else +#include #include void THCTensor_(calculateMode)(THCState *state, @@ -235,14 +236,14 @@ void THCTensor_(mode)(THCState *state, // Macro that calls kernel --> note that we set the block dimensions here, and // the amount of shared memory - #define HANDLE_MODE(SIZE) \ - { \ - dim3 blockSize(SIZE / 2); \ -\ - int memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \ - computeMode \ - <<>>( \ - THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ + #define HANDLE_MODE(SIZE) \ + { \ + const dim3 blockSize(SIZE / 2); \ + const auto memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \ + computeMode \ + <<>>( \ + THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } // Tradeoff between compilation time and the number of specializations. Ideally we would have diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu index f3ca8bf93b1b..1ef540ba3302 100644 --- a/aten/src/THC/generic/THCTensorRandom.cu +++ b/aten/src/THC/generic/THCTensorRandom.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) @@ -39,6 +40,8 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, larger_short), one, inputsize ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + at::Tensor smaller_short_wrapped = THTensor_wrap(smaller_short); at::Tensor smaller_wrapped = THTensor_wrap(smaller); at::Tensor larger_short_wrapped = THTensor_wrap(larger_short); @@ -57,6 +60,8 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, larger_short), inputsize - h_large_c, h_large_c ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + scalar_t q_max = at::max(THTensor_wrap(_q)).item(); condDiv<<< inputBlockDim, BLOCK_SIZE, 0, c10::cuda::getCurrentCUDAStream()>>>( @@ -64,6 +69,7 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, _J), inputsize, q_max ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); THCudaLongTensor_free(state, smaller); THCudaLongTensor_free(state, larger); @@ -104,6 +110,8 @@ void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, T THCTensor_(data)(state, uniform), THCTensor_(data)(state, bernoulli) ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + THCTensor_(free)(state, uniform); THCTensor_(free)(state, bernoulli); } diff --git a/aten/src/THC/generic/THCTensorScatterGather.cu b/aten/src/THC/generic/THCTensorScatterGather.cu index 832539d370ce..a1ab8d63f163 100644 --- a/aten/src/THC/generic/THCTensorScatterGather.cu +++ b/aten/src/THC/generic/THCTensorScatterGather.cu @@ -2,10 +2,13 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorScatterGather.cu" #else +#include + #define RUN(TYPE, DIMS, REAL) \ - THCudaTensor_gatherKernel \ - <<>>( \ - tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements); + THCudaTensor_gatherKernel \ + <<>>( \ + tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index) { @@ -61,19 +64,15 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, switch (indexInfo.dims) { case 1: RUN(unsigned int, 1, scalar_t); - THCudaCheck(cudaGetLastError()); break; case 2: RUN(unsigned int, 2, scalar_t); - THCudaCheck(cudaGetLastError()); break; case 3: RUN(unsigned int, 3, scalar_t); - THCudaCheck(cudaGetLastError()); break; default: RUN(unsigned int, -1, scalar_t); - THCudaCheck(cudaGetLastError()); break; } } else { @@ -84,7 +83,6 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, TensorInfo indexInfo = getTensorInfo(state, index); RUN(uint64_t, -1, scalar_t); - THCudaCheck(cudaGetLastError()); } } diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu index b4da00a98b7f..e378fe03358e 100644 --- a/aten/src/THC/generic/THCTensorSort.cu +++ b/aten/src/THC/generic/THCTensorSort.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorSort.cu" #else +#include + // In alignment with default sort on a c++ map, this function // will permute key and value tensors identically, and // in such a way that the 'key' tensor is ordered numerically @@ -53,8 +55,9 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, dim3 block(blockSize); \ \ if (dir) { \ - bitonicSortKVInPlace, TYPE, SIZE> \ - <<>>( \ + bitonicSortKVInPlace, TYPE, SIZE> \ + <<>>( \ keyInfo, \ keySlices, \ (TYPE) keySliceSize, \ @@ -62,16 +65,19 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, valueInfo, \ (TYPE) valueInfo.strides[collapseValueDim], \ GTComp()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else { \ - bitonicSortKVInPlace, TYPE, SIZE> \ - <<>>( \ + bitonicSortKVInPlace, TYPE, SIZE> \ + <<>>( \ keyInfo, \ keySlices, \ (TYPE) keySliceSize, \ (TYPE) keyInfo.strides[collapseKeyDim], \ valueInfo, \ (TYPE) valueInfo.strides[collapseValueDim], \ - LTComp()); \ + LTComp()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } while (0) @@ -147,8 +153,6 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, #undef HANDLE_CASE #undef HANDLE_SORT_CASE #undef HANDLE_A_CASE - - THCudaCheck(cudaGetLastError()); } void THCTensor_(sortViaThrust)(THCState* state, diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index 357b3f2e22f3..8d7bf7701c04 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -3,6 +3,7 @@ #else #include +#include void THCTensor_(topk)(THCState* state, THCTensor *topK, @@ -37,8 +38,8 @@ void THCTensor_(topk)(THCState* state, // is provided to the kernel for the arguments. #define RUN_K(INDEX_T, DIM, DIR) \ - gatherTopK \ - <<>>( \ + gatherTopK \ + <<>>( \ inputInfo, \ static_cast(sliceSize), \ static_cast(k), \ @@ -50,7 +51,8 @@ void THCTensor_(topk)(THCState* state, static_cast(topKSlices), \ static_cast(topKInfo.strides[collapseTopKDim]), \ indicesInfo, \ - static_cast(indicesInfo.strides[collapseIndicesDim])) + static_cast(indicesInfo.strides[collapseIndicesDim])); \ + C10_CUDA_KERNEL_LAUNCH_CHECK() #define RUN_DIR(INDEX_T, DIM) \ if (dir) { \ @@ -71,10 +73,10 @@ void THCTensor_(topk)(THCState* state, } #define RUN_T(INDEX_T) \ - TensorInfo inputInfo = \ - getTensorInfo(state, input); \ - TensorInfo topKInfo = \ - getTensorInfo(state, topK); \ + TensorInfo inputInfo = \ + getTensorInfo(state, input); \ + TensorInfo topKInfo = \ + getTensorInfo(state, topK); \ TensorInfo indicesInfo = \ getTensorInfo(state, indices); \ \ diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py index 4fb7fffb5a5d..b0534bd9722d 100644 --- a/benchmarks/operator_benchmark/benchmark_caffe2.py +++ b/benchmarks/operator_benchmark/benchmark_caffe2.py @@ -50,10 +50,15 @@ def tensor(self, shapes, dtype='float32', device='cpu'): Return: C2 tensor of dtype """ + return self.feed_tensor(benchmark_utils.numpy_random(dtype, *shapes), device) + + def feed_tensor(self, tensor, device='cpu'): + """ Similar to tensor, but can supply any data compatible with FeedBlob + """ blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index) dev = self._device_option(device) with core.DeviceScope(dev): - workspace.FeedBlob(blob_name, benchmark_utils.numpy_random(dtype, *shapes)) + workspace.FeedBlob(blob_name, tensor) Caffe2BenchmarkBase.tensor_index += 1 return blob_name diff --git a/benchmarks/operator_benchmark/c2/batch_gather_test.py b/benchmarks/operator_benchmark/c2/batch_gather_test.py new file mode 100644 index 000000000000..ff3d84b99b2b --- /dev/null +++ b/benchmarks/operator_benchmark/c2/batch_gather_test.py @@ -0,0 +1,56 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core +import numpy + + +"""Microbenchmarks for element-wise BatchGather operator.""" + +# Configs for C2 BatherGather operator +batch_gather_configs_short = op_bench.config_list( + attr_names=["M", "N", "K"], + attrs=[ + [8, 8, 1], + [256, 512, 1], + [512, 512, 1], + [8, 8, 2], + [256, 512, 2], + [512, 512, 2], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"] +) + +batch_gather_configs_long = op_bench.cross_product_configs( + M=[128, 1024], + N=[128, 1024], + K=[1, 2], + device=['cpu', 'cuda'], + tags=["long"] +) + +class BatchGatherBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, M, N, K, device): + self.input_one = self.tensor([M, N, K], device=device) + max_val = N + numpy.random.seed((1 << 32) - 1) + index_dim = numpy.random.randint(0, N) + self.index = self.feed_tensor(numpy.random.randint(0, max_val, index_dim), device=device) + self.output = self.tensor([M, index_dim, K], device=device) + self.set_module_name("batch_gather") + + def forward(self): + op = core.CreateOperator("BatchGather", [self.input_one, self.index], self.output) + return op + + +op_bench_c2.generate_c2_test( + batch_gather_configs_long + batch_gather_configs_short, BatchGatherBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/index_select_test.py b/benchmarks/operator_benchmark/pt/index_select_test.py new file mode 100644 index 000000000000..8418edb2840b --- /dev/null +++ b/benchmarks/operator_benchmark/pt/index_select_test.py @@ -0,0 +1,57 @@ +import operator_benchmark as op_bench +import torch +import numpy + + +"""Microbenchmarks for index_select operator.""" + +# An example input from this configuration is M=4, N=4, dim=0. +index_select_configs_short = op_bench.config_list( + attr_names=["M", "N", "K", "dim"], + attrs=[ + [8, 8, 1, 1], + [256, 512, 1, 1], + [512, 512, 1, 1], + [8, 8, 2, 1], + [256, 512, 2, 1], + [512, 512, 2, 1], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"] +) + + +index_select_configs_long = op_bench.cross_product_configs( + M=[128, 1024], + N=[128, 1024], + K=[1, 2], + dim=[1], + device=['cpu', 'cuda'], + tags=["long"] +) + + +class IndexSelectBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, dim, device): + max_val = N + numpy.random.seed((1 << 32) - 1) + index_dim = numpy.random.randint(0, N) + self.inputs = { + "input_one": torch.rand(M, N, K, device=device), + "dim" : dim, + "index" : torch.tensor(numpy.random.randint(0, max_val, index_dim), device=device), + } + self.set_module_name("index_select") + + def forward(self, input_one, dim, index): + return torch.index_select(input_one, dim, index) + + +op_bench.generate_pt_test(index_select_configs_short + index_select_configs_long, + IndexSelectBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt_extension/extension.cpp b/benchmarks/operator_benchmark/pt_extension/extension.cpp index 2e665604c6ed..2dbdfdd8b3e6 100644 --- a/benchmarks/operator_benchmark/pt_extension/extension.cpp +++ b/benchmarks/operator_benchmark/pt_extension/extension.cpp @@ -17,9 +17,10 @@ List consume_list(List a) { // That caused an issue for our op benchmark which needs to run an op // in a loop and report the execution time. This diff resolves that issue by // registering this consume op with correct alias information which is DEFAULT. -auto reg = torch::RegisterOperators() - .op("operator_benchmark::_consume", &consume) - .op("operator_benchmark::_consume.list", &consume_list); +TORCH_LIBRARY_FRAGMENT(operator_benchmark, m) { + m.def("_consume", &consume); + m.def("_consume.list", &consume_list); +} PYBIND11_MODULE(cpp_extension, m) { m.def("_consume", &consume, "consume"); diff --git a/benchmarks/static_runtime/deep_wide_pt.h b/benchmarks/static_runtime/deep_wide_pt.h index d6eae2f8b4ca..c473eaf1bb95 100644 --- a/benchmarks/static_runtime/deep_wide_pt.h +++ b/benchmarks/static_runtime/deep_wide_pt.h @@ -50,7 +50,7 @@ struct DeepAndWideFast : torch::nn::Module { torch::Tensor wide) { torch::NoGradGuard no_grad; if (!allocated) { - auto wide_offset = at::native::add(wide, mu_); + auto wide_offset = at::add(wide, mu_); auto wide_normalized = at::native::mul(wide_offset, sigma_); // Placeholder for ReplaceNaN auto wide_preproc = at::native::clamp(wide_normalized, -10.0, 10.0); @@ -82,7 +82,7 @@ struct DeepAndWideFast : torch::nn::Module { } else { // Potential optimization: add and mul could be fused together (e.g. with // Eigen). - at::native::add_out(prealloc_tensors[0], wide, mu_); + at::add_out(prealloc_tensors[0], wide, mu_); at::native::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_); at::native::clamp_out( diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 07f9ac253b9f..251e2654b013 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -1,4 +1,5 @@ #include +#include #include #include "deep_wide_pt.h" #include "test_scripts.h" @@ -249,3 +250,34 @@ TEST(StaticRuntime, CleanUpMemory) { } } } + +TEST(StaticRuntime, FusionPass) { + const int embedding_size = 32; + const int num_features = 50; + for (int batch_size : {1, 8, 32}) { + for (int i = 0; i < 2; ++i) { + torch::jit::Module module = getDeepAndWideSciptModel(); + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + + // run jit graph executor + std::vector inputs({ad_emb_packed, user_emb, wide}); + auto output_1 = getTensor(module.forward(inputs)); + + Method method = module.get_method("forward"); + auto graph = method.graph(); + fuseStaticSubgraphs(graph); + bool hit = false; + for (const auto& n : module.get_method("forward").graph()->nodes()) { + if (n->kind() == torch::jit::prim::StaticSubgraph) { + hit = true; + } + } + EXPECT_TRUE(hit); + auto output_2 = getTensor(module.forward(inputs)); + EXPECT_TRUE(output_1.equal(output_2)); + } + } +} + diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index d924003b9270..d47cedada40f 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -7,61 +7,55 @@ #include #include -C10_DEFINE_int(iter, 100, "Number of iterations"); -C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations"); +C10_DEFINE_int(iter, 10000, "Number of iterations"); C10_DEFINE_int(sampled_iter, 10e6, "Number of iterations for the sampled observer benchmark"); namespace { -const int kInnerIter = 100; -const int kNumSampledCb = 2; const int kTensorSize = 16; const int kSmallTensorSize = 1; -const float kSampingProb = 0.1; - const float kLowSamplingProb = 0.0001; } -void setupBenchmarkCallbacks() { - at::enableRecordFunction(); - at::clearCallbacks(); - // non-sampled callback - at::addGlobalCallback(at::RecordFunctionCallback( - [&](const at::RecordFunction& fn) {}, - [](const at::RecordFunction&) {}) - .needsInputs(true)); - - // sampled - for (auto idx = 0; idx < kNumSampledCb; ++idx) { - at::addGlobalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) {}, - [](const at::RecordFunction&) {}) - .needsInputs(true) - .samplingProb(kSampingProb) - ); +void addTestCallback( + double sampling_prob = 1.0, + std::function(const at::RecordFunction&)> fn = + [](const at::RecordFunction&) { return nullptr; }) { + auto cb = at::RecordFunctionCallback( + std::move(fn), + [](const at::RecordFunction&, at::ObserverContext*) {}) + .needsInputs(false); + if (sampling_prob < 1.0) { + cb.samplingProb(sampling_prob); } + at::addGlobalCallback(cb); } -float runTensorBench(int tensor_size, int outer_iter) { +float runTensorGEMMBench(int tensor_size, int iter) { typedef std::chrono::high_resolution_clock clock; typedef std::chrono::microseconds us; std::chrono::time_point start_time = clock::now(); - for (auto idx = 0; idx < kInnerIter * outer_iter; ++idx) { - torch::mm( - torch::randn({tensor_size, tensor_size}), - torch::randn({tensor_size, tensor_size})); + auto inp = torch::randn({tensor_size, tensor_size}); + for (auto idx = 0; idx < iter; ++idx) { + torch::mm(inp, inp); } auto duration = static_cast( std::chrono::duration_cast(clock::now() - start_time).count()); return duration; } -float runPureRecordFunctionBench(int outer_iter) { +float runPureRecordFunctionBench(int iter) { typedef std::chrono::high_resolution_clock clock; typedef std::chrono::microseconds us; std::chrono::time_point start_time = clock::now(); - for (auto n = 0; n < outer_iter; ++n) { - RECORD_USER_SCOPE("test"); + for (auto idx = 0; idx < iter; ++idx) { + bool pre_sampled = false; + if (at::shouldRunRecordFunction(&pre_sampled)) { + at::RecordFunction guard(at::RecordScope::USER_SCOPE, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + guard.before("Test", -1); + } + } } auto duration = static_cast( std::chrono::duration_cast(clock::now() - start_time).count()); @@ -71,18 +65,19 @@ float runPureRecordFunctionBench(int outer_iter) { void runBenchmark() { float duration = 0; for (auto tensor_size : std::set({kSmallTensorSize, kTensorSize})) { - duration = runTensorBench(tensor_size, FLAGS_iter); - std::cout << "Running tensor benchmark, time per iteration (" + duration = runTensorGEMMBench(tensor_size, FLAGS_iter); + std::cout << "Tensor GEMM benchmark (" << tensor_size << "x" << tensor_size - << "): " << (duration/FLAGS_iter) + << ", " << FLAGS_iter << "): " << duration << " us." << std::endl; } - duration = runPureRecordFunctionBench(FLAGS_iter * 100); - std::cout << "Running pure RecordFunction benchmark, time per iteration: " - << (duration/FLAGS_iter) - << " us." << std::endl; + duration = runPureRecordFunctionBench(FLAGS_iter); + std::cout << "Pure RecordFunction benchmark (" + << FLAGS_iter << "): " + << duration + << " us." << std::endl; } int main(int argc, char** argv) { @@ -91,32 +86,39 @@ int main(int argc, char** argv) { return -1; } - auto duration = runTensorBench(kSmallTensorSize, FLAGS_warmup_iter); - std::cout << "Warmup time: " << duration << " us." << std::endl; + at::enableRecordFunction(); + at::clearCallbacks(); - setupBenchmarkCallbacks(); - std::cout << "Running with empty observers" << std::endl; + std::cout << "Warm up" << std::endl; runBenchmark(); - at::clearCallbacks(); std::cout << "Running without observers" << std::endl; runBenchmark(); - std::cout << "Running sampled observer benchmark" << std::endl; + addTestCallback(); + std::cout << "Running with empty non-sampled observer" << std::endl; + runBenchmark(); + at::clearCallbacks(); + + addTestCallback(kLowSamplingProb); + std::cout << "Running with empty sampled observer" << std::endl; + runBenchmark(); + at::clearCallbacks(); + + std::cout << "Checking number of sampled observer invocations" << std::endl; int cb_count = 0; - at::addGlobalCallback(at::RecordFunctionCallback( + addTestCallback( + kLowSamplingProb, [&](const at::RecordFunction& fn) { ++cb_count; - }, - [](const at::RecordFunction&) {}) - .needsInputs(true) - .samplingProb(kLowSamplingProb) + return nullptr; + } ); - runPureRecordFunctionBench(FLAGS_sampled_iter); + auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter); std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter - << " iterations " << duration + << " iterations: " << duration << " us, number of callback invocations: " << cb_count << ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb) << " invocations" << std::endl; diff --git a/c10/core/Device.h b/c10/core/Device.h index 7827119bb0ac..04cd711c37b2 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -93,9 +93,13 @@ struct C10_API Device final { DeviceType type_; DeviceIndex index_ = -1; void validate() { - TORCH_CHECK(index_ == -1 || index_ >= 0, + // Removing these checks in release builds noticeably improves + // performance in micro-benchmarks. + // This is safe to do, because backends that use the DeviceIndex + // have a later check when we actually try to switch to that device. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(index_ == -1 || index_ >= 0, "Device index must be -1 or non-negative, got ", (int)index_); - TORCH_CHECK(!is_cpu() || index_ <= 0, + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_cpu() || index_ <= 0, "CPU device index must be -1 or zero, got ", (int)index_); } }; diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 1e9d85211f6d..486272ece92e 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -61,8 +61,8 @@ class DispatchKeySet final { } } // Test if a DispatchKey is in the set - bool has(DispatchKey t) const { - TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); + bool inline has(DispatchKey t) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); return static_cast(repr_ & DispatchKeySet(t).repr_); } // Test if DispatchKeySet is a superset of ks. diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index a9e8f1f6853f..e305f352d7cb 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -293,7 +293,8 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { auto impl = c10::make_intrusive( - Storage(storage()), key_set_, data_type_); + // No need to populate Storage; copy_tensor_metadata will do it for us. + key_set_, data_type_, device_opt_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -308,7 +309,8 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) const { auto impl = c10::make_intrusive( - Storage(storage()), key_set_, data_type_); + // No need to populate Storage; copy_tensor_metadata will do it for us. + key_set_, data_type_, device_opt_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 269976a7e148..5deab2a09832 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1706,7 +1706,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { is_channels_last_contiguous_ = false; is_channels_last_3d_ = false; is_channels_last_3d_contiguous_ = false; - is_non_overlapping_and_dense_ = false; + is_non_overlapping_and_dense_ = true; is_wrapped_number_ = false; allow_tensor_metadata_change_ = true; reserved_ = false; diff --git a/c10/test/util/Metaprogramming_test.cpp b/c10/test/util/Metaprogramming_test.cpp index 88c8e0facad1..63613980079d 100644 --- a/c10/test/util/Metaprogramming_test.cpp +++ b/c10/test/util/Metaprogramming_test.cpp @@ -476,4 +476,22 @@ namespace test_tuple_concat { } } +namespace test_concat_iseq { + using std::index_sequence; + using std::integer_sequence; + static_assert(std::is_same, concat_iseq_t<>>::value, ""); + static_assert(std::is_same, concat_iseq_t>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4>, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<2>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4, 2>, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4, 2>, index_sequence<9>>>::value, ""); + + static_assert(std::is_same, concat_iseq_t, integer_sequence>>::value, ""); +} + + } diff --git a/c10/util/Exception.cpp b/c10/util/Exception.cpp index 3f2c34ffae6c..1c6363326343 100644 --- a/c10/util/Exception.cpp +++ b/c10/util/Exception.cpp @@ -76,6 +76,14 @@ void Error::add_context(std::string new_msg) { refresh_what(); } +namespace detail { + +void torchCheckFail(const char *func, const char *file, uint32_t line, const std::string& msg) { + throw ::c10::Error({func, file, line}, msg); +} + +} // namespace detail + namespace Warning { namespace { diff --git a/c10/util/Exception.h b/c10/util/Exception.h index fed17a4cf526..ebd1e872251e 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -194,7 +194,7 @@ C10_API std::string GetExceptionString(const std::exception& e); namespace detail { // Return x if it is non-empty; otherwise return y. -inline std::string if_empty_then(std::string x, std::string y) { +inline std::string if_empty_then(const std::string& x, const std::string& y) { if (x.empty()) { return y; } else { @@ -324,27 +324,45 @@ inline std::string if_empty_then(std::string x, std::string y) { TORCH_CHECK_WITH_MSG(error_t, cond, "", __VA_ARGS__) #ifdef STRIP_ERROR_MESSAGES -#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ - if (C10_UNLIKELY_OR_CONST(!(cond))) { \ - C10_THROW_ERROR(Error, \ - #cond #type " CHECK FAILED at " \ - C10_STRINGIZE(__FILE__) \ - ); \ +#define TORCH_CHECK_MSG(cond, type, ...) \ + (#cond #type " CHECK FAILED at " \ + C10_STRINGIZE(__FILE__)) +#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + C10_THROW_ERROR(Error, \ + TORCH_CHECK_MSG(cond, type, __VA_ARGS__) \ + ); \ } #else +#define TORCH_CHECK_MSG(cond, type, ...) \ + ::c10::detail::if_empty_then( \ + ::c10::str(__VA_ARGS__), \ + "Expected " #cond " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)" \ + ) #define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ if (C10_UNLIKELY_OR_CONST(!(cond))) { \ C10_THROW_ERROR(error_t, \ - ::c10::detail::if_empty_then( \ - ::c10::str(__VA_ARGS__), \ - "Expected " #cond " to be true, but got false. " \ - "(Could this error message be improved? If so, " \ - "please report an enhancement request to PyTorch.)" \ - ) \ + TORCH_CHECK_MSG(cond, type, __VA_ARGS__) \ ); \ } #endif -#define TORCH_CHECK(cond, ...) TORCH_CHECK_WITH(Error, cond, __VA_ARGS__) + +namespace c10 { +namespace detail { + +[[noreturn]] C10_API void torchCheckFail(const char *func, const char *file, uint32_t line, const std::string& msg); + +} // namespace detail +} // namespace 10 + +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, __FILE__, static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } // An utility macro that does what `TORCH_CHECK` does if compiled in the host code, // otherwise does nothing. Supposed to be used in the code shared between host and diff --git a/c10/util/Metaprogramming.h b/c10/util/Metaprogramming.h index ae929a93ca09..a56b43afa852 100644 --- a/c10/util/Metaprogramming.h +++ b/c10/util/Metaprogramming.h @@ -309,4 +309,29 @@ template } +/** + * Concatenate multiple integer sequences + * Example: + * concat_iseq_t, std::index_sequence<4, 2>, std::index_sequence<5>> + * == std::index_sequence<2, 5, 3, 4, 2, 5> + */ +template struct concat_iseq { + static_assert(false_t::value, "In concat_iseq, the T arguments each must be std::integer_sequence<...> with the same IntType."); +}; +template<> +struct concat_iseq<> { + using type = std::index_sequence<>; +}; +template +struct concat_iseq> { + using type = std::integer_sequence; +}; +template +struct concat_iseq, std::integer_sequence, TailISeqs...> { + using type = typename concat_iseq, TailISeqs...>::type; +}; +template +using concat_iseq_t = typename concat_iseq::type; + + }} diff --git a/c10/util/Unicode.h b/c10/util/Unicode.h new file mode 100644 index 000000000000..9cce93cc9b83 --- /dev/null +++ b/c10/util/Unicode.h @@ -0,0 +1,29 @@ +#pragma once + +#if defined(_WIN32) +#include +#include +#include +#endif + +namespace c10 { +#if defined(_WIN32) +inline std::wstring u8u16(const std::string& str) { + if (str.empty()) { + return std::wstring(); + } + int size_needed = MultiByteToWideChar( + CP_UTF8, 0, str.c_str(), static_cast(str.size()), NULL, 0); + TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); + std::wstring wstr(size_needed, 0); + MultiByteToWideChar( + CP_UTF8, + 0, + str.c_str(), + static_cast(str.size()), + &wstr[0], + size_needed); + return wstr; +} +#endif +} diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py index 36d6ba73e0c3..f992c6f9e1fc 100644 --- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py @@ -27,7 +27,7 @@ class LayerNorm(serial.SerializedTestCase): epsilon=st.floats(min_value=1e-4, max_value=1e-3), elementwise_affine=st.booleans()) @settings(deadline=datetime.timedelta(seconds=10)) - def Skip_test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine): + def test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine): np.random.seed(seed) # Reset the workspace workspace.ResetWorkspace() @@ -142,7 +142,7 @@ def _layernorm_transform(self, X): elementwise_affine=st.booleans()) @settings(deadline=datetime.timedelta(seconds=10)) # re-enable when T74553975 gets fixed - def Skip_test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine): + def test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine): np.random.seed(seed) # Reset the workspace diff --git a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py index c5aea77d7199..041dcce97dbf 100644 --- a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py @@ -418,6 +418,147 @@ def test_small_sls(self, seed): ) assert 0 + @given(seed=st.integers(0, 65535)) + @settings(deadline=datetime.timedelta(seconds=10)) + def test_sls_layernorm(self, seed): + np.random.seed(seed) + workspace.ResetWorkspace() + + n = 2 + DIM = 3 + data = 4 * (np.random.random_sample((n, DIM)) + 1).astype(np.float32) + + lengths = np.array([n], dtype=np.int32) + indices = np.array(range(n), dtype=np.int64) + weights = np.random.uniform(low=0.01, high=0.5, size=[n]).astype(np.float32) + + pred_net = caffe2_pb2.NetDef() + pred_net.name = "pred" + pred_net.external_input.extend( + ["quantized_data", "weights", "indices", "lengths"] + ) + pred_net.external_output.append("Y_norm") + pred_net.external_output.append("Y_mean") + pred_net.external_output.append("Y_std") + + pred_net.op.add().CopyFrom( + core.CreateOperator( + "SparseLengthsWeightedSumFused8BitRowwise", + ["quantized_data", "weights", "indices", "lengths"], + ["Y"], + ) + ) + + pred_net.op.add().CopyFrom( + core.CreateOperator( + "LayerNorm", + ["Y"], + ["Y_norm", "Y_mean", "Y_std"], + epsilon=1e-4, + ) + ) + + ref_net = caffe2_pb2.NetDef() + ref_net.name = "ref" + ref_net.external_input.extend( + ["quantized_data", "weights", "indices", "lengths"] + ) + ref_net.external_output.append("Y_norm") + ref_net.external_output.append("Y_mean") + ref_net.external_output.append("Y_std") + + ref_net.op.add().CopyFrom( + core.CreateOperator( + "SparseLengthsWeightedSumFused8BitRowwiseFakeFP16NNPI", + ["quantized_data", "weights", "indices", "lengths"], + ["Y"], + ) + ) + + ref_net.op.add().CopyFrom( + core.CreateOperator( + "LayerNormFakeFP16NNPI", + ["Y"], + ["Y_norm", "Y_mean", "Y_std"], + epsilon=1e-4, + axis=1, + elementwise_affine=False + ) + ) + + workspace.FeedBlob("data", data) + workspace.RunOperatorOnce( + core.CreateOperator( + "FloatToFused8BitRowwiseQuantized", ["data"], ["quantized_data"] + ) + ) + + quantized_data = workspace.FetchBlob("quantized_data") + + onnxified_net = onnxifi_caffe2_net( + pred_net, + {}, + max_batch_size=1, + max_seq_size=n, + debug=True, + adjust_batch=True, + use_onnx=False, + ) + print("before", pred_net) + print("after", onnxified_net) + workspace.FeedBlob("indices", indices) + workspace.FeedBlob("lengths", lengths) + workspace.FeedBlob("weights", weights) + + workspace.CreateNet(onnxified_net) + workspace.CreateNet(ref_net) + + workspace.RunNet(onnxified_net.name) + Y_glow = workspace.FetchBlob("Y_norm") + Y_mean_glow = workspace.FetchBlob("Y_mean") + Y_std_glow = workspace.FetchBlob("Y_std") + + workspace.RunNet(ref_net.name) + Y = workspace.FetchBlob("Y") + print("pre normalization", Y) + Y_ref = workspace.FetchBlob("Y_norm") + Y_mean_ref = workspace.FetchBlob("Y_mean") + Y_std_ref = workspace.FetchBlob("Y_std") + + # print(Y_ref, Y_glow) + # print(Y_ref.shape, Y_glow.shape) + + diff = np.abs(Y_ref - Y_glow) + max_err = np.max(diff, axis=1) + num_offenders = (max_err > 0).sum() + if num_offenders > 0: + np.set_printoptions(precision=12) + print( + "ref", + Y_ref.astype(np.float16).astype(np.float32), + "glow", + Y_glow.astype(np.float16).astype(np.float32), + ) + print_test_debug_info( + "slws_fused_8bit_rowwise_inv_scale", + { + "seed": seed, + "indices": indices, + "data": data, + "quantized_data": quantized_data, + "lengths": lengths, + "weights": weights, + "Y_norm_glow": Y_glow, + "Y_norm_ref": Y_ref, + "Y_mean_glow": Y_mean_glow, + "Y_std_glow": Y_std_glow, + "Y_mean_ref": Y_mean_ref, + "Y_std_ref": Y_std_ref, + "diff": diff, + "rowwise_diff": np.max(diff, axis=1), + }, + ) + assert 0 if __name__ == '__main__': diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index dd422c5b44cc..fcf08eebfa8a 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -26,17 +26,6 @@ C10_DEFINE_bool( false, "Serialize BOOL, UINT8, INT8, UINT16, INT16, INT64, FLOAT16 tensors using byte_data field instead of int32"); -#ifdef _MSC_VER -// It's MSVC, so we just have to guess ... and allow an override -#ifdef FOLLY_ENDIAN_BE -constexpr auto kIsLittleEndian = false; -#else -constexpr auto kIsLittleEndian = true; -#endif -#else -constexpr auto kIsLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; -#endif - namespace caffe2 { /** * @brief StringSerializer is the serializer for String. @@ -420,7 +409,7 @@ void DeserializeBlob(const BlobProto& blob_proto, Blob* result) { // === Local helper functions === // Get dimensions from Tensor proto -static std::vector DimsFromTensorProto(const TensorProto& proto) { +std::vector DimsFromTensorProto(const TensorProto& proto) { std::vector dims; dims.reserve(proto.dims().size()); for (const int64_t d : proto.dims()) { @@ -430,7 +419,7 @@ static std::vector DimsFromTensorProto(const TensorProto& proto) { } // Get number of elements from Tensor proto -static int64_t NumelFromTensorProto(const TensorProto& tensor_proto) { +int64_t NumelFromTensorProto(const TensorProto& tensor_proto) { int64_t numel = 1; for (const int64_t d : tensor_proto.dims()) { numel *= d; @@ -439,7 +428,7 @@ static int64_t NumelFromTensorProto(const TensorProto& tensor_proto) { } // Get data type from Tensor proto -static TypeMeta GetDataType(const TensorProto& tensor_proto) { +TypeMeta GetDataType(const TensorProto& tensor_proto) { TypeMeta dtype; if (tensor_proto.data_type() != TensorProto_DataType_UNDEFINED) { dtype = DataTypeToTypeMeta(tensor_proto.data_type()); @@ -459,7 +448,7 @@ static at::TensorOptions TensorOptionsFromProto( .device(OptionToDevice(tensor_proto.device_detail())); } -static std::unique_ptr ContextFromProto( +std::unique_ptr ContextFromProto( const TensorProto& tensor_proto) { auto device = OptionToDevice(tensor_proto.device_detail()); return CreateContext(device); diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h index 5309314af0c7..72d148c86775 100644 --- a/caffe2/core/blob_serialization.h +++ b/caffe2/core/blob_serialization.h @@ -17,6 +17,17 @@ C10_DECLARE_int(caffe2_tensor_chunk_size); C10_DECLARE_int(caffe2_max_tensor_serializer_threads); C10_DECLARE_bool(caffe2_serialize_fp16_as_bytes); +#ifdef _MSC_VER +// It's MSVC, so we just have to guess ... and allow an override +#ifdef FOLLY_ENDIAN_BE +constexpr auto kIsLittleEndian = false; +#else +constexpr auto kIsLittleEndian = true; +#endif +#else +constexpr auto kIsLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; +#endif + namespace caffe2 { constexpr auto kTensorBlobType = "Tensor"; @@ -239,6 +250,14 @@ inline std::string SerializeBlobProtoAsString_EnforceCheck( return SerializeAsString_EnforceCheck(blob, blob.name().c_str()); } +int64_t NumelFromTensorProto(const TensorProto& tensor_proto); + +std::vector DimsFromTensorProto(const TensorProto& proto); + +TypeMeta GetDataType(const TensorProto& tensor_proto); + +std::unique_ptr ContextFromProto(const TensorProto& tensor_proto); + } // namespace caffe2 #endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_ diff --git a/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h b/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h index cc0ad1f72d01..04495fa0cd72 100644 --- a/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h +++ b/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h @@ -6,7 +6,7 @@ extern "C" { #endif /* -** Copyright (c) 2015-2016 The Khronos Group Inc. +** Copyright (c) 2015-2017 The Khronos Group Inc. ** ** Licensed under the Apache License, Version 2.0 (the "License"); ** you may not use this file except in compliance with the License. @@ -28,22 +28,22 @@ extern "C" { #define VK_VERSION_1_0 1 -#include "vk_platform.h" +#include "./vk_platform.h" #define VK_MAKE_VERSION(major, minor, patch) \ (((major) << 22) | ((minor) << 12) | (patch)) // DEPRECATED: This define has been removed. Specific version defines (e.g. VK_API_VERSION_1_0), or the VK_MAKE_VERSION macro, should be used instead. -//#define VK_API_VERSION VK_MAKE_VERSION(1, 0, 0) +//#define VK_API_VERSION VK_MAKE_VERSION(1, 0, 0) // Patch version should always be set to 0 // Vulkan 1.0 version number -#define VK_API_VERSION_1_0 VK_MAKE_VERSION(1, 0, 0) +#define VK_API_VERSION_1_0 VK_MAKE_VERSION(1, 0, 0)// Patch version should always be set to 0 #define VK_VERSION_MAJOR(version) ((uint32_t)(version) >> 22) #define VK_VERSION_MINOR(version) (((uint32_t)(version) >> 12) & 0x3ff) #define VK_VERSION_PATCH(version) ((uint32_t)(version) & 0xfff) // Version of this file -#define VK_HEADER_VERSION 29 +#define VK_HEADER_VERSION 59 #define VK_NULL_HANDLE 0 @@ -145,6 +145,8 @@ typedef enum VkResult { VK_ERROR_INCOMPATIBLE_DISPLAY_KHR = -1000003001, VK_ERROR_VALIDATION_FAILED_EXT = -1000011001, VK_ERROR_INVALID_SHADER_NV = -1000012000, + VK_ERROR_OUT_OF_POOL_MEMORY_KHR = -1000069000, + VK_ERROR_INVALID_EXTERNAL_HANDLE_KHR = -1000072003, VK_RESULT_BEGIN_RANGE = VK_ERROR_FRAGMENTED_POOL, VK_RESULT_END_RANGE = VK_INCOMPLETE, VK_RESULT_RANGE_SIZE = (VK_INCOMPLETE - VK_ERROR_FRAGMENTED_POOL + 1), @@ -220,12 +222,117 @@ typedef enum VkStructureType { VK_STRUCTURE_TYPE_DEDICATED_ALLOCATION_IMAGE_CREATE_INFO_NV = 1000026000, VK_STRUCTURE_TYPE_DEDICATED_ALLOCATION_BUFFER_CREATE_INFO_NV = 1000026001, VK_STRUCTURE_TYPE_DEDICATED_ALLOCATION_MEMORY_ALLOCATE_INFO_NV = 1000026002, + VK_STRUCTURE_TYPE_TEXTURE_LOD_GATHER_FORMAT_PROPERTIES_AMD = 1000041000, + VK_STRUCTURE_TYPE_RENDER_PASS_MULTIVIEW_CREATE_INFO_KHX = 1000053000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MULTIVIEW_FEATURES_KHX = 1000053001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MULTIVIEW_PROPERTIES_KHX = 1000053002, VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO_NV = 1000056000, VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_NV = 1000056001, VK_STRUCTURE_TYPE_IMPORT_MEMORY_WIN32_HANDLE_INFO_NV = 1000057000, VK_STRUCTURE_TYPE_EXPORT_MEMORY_WIN32_HANDLE_INFO_NV = 1000057001, VK_STRUCTURE_TYPE_WIN32_KEYED_MUTEX_ACQUIRE_RELEASE_INFO_NV = 1000058000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2_KHR = 1000059000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2_KHR = 1000059001, + VK_STRUCTURE_TYPE_FORMAT_PROPERTIES_2_KHR = 1000059002, + VK_STRUCTURE_TYPE_IMAGE_FORMAT_PROPERTIES_2_KHR = 1000059003, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_IMAGE_FORMAT_INFO_2_KHR = 1000059004, + VK_STRUCTURE_TYPE_QUEUE_FAMILY_PROPERTIES_2_KHR = 1000059005, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2_KHR = 1000059006, + VK_STRUCTURE_TYPE_SPARSE_IMAGE_FORMAT_PROPERTIES_2_KHR = 1000059007, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SPARSE_IMAGE_FORMAT_INFO_2_KHR = 1000059008, + VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHX = 1000060000, + VK_STRUCTURE_TYPE_BIND_BUFFER_MEMORY_INFO_KHX = 1000060001, + VK_STRUCTURE_TYPE_BIND_IMAGE_MEMORY_INFO_KHX = 1000060002, + VK_STRUCTURE_TYPE_DEVICE_GROUP_RENDER_PASS_BEGIN_INFO_KHX = 1000060003, + VK_STRUCTURE_TYPE_DEVICE_GROUP_COMMAND_BUFFER_BEGIN_INFO_KHX = 1000060004, + VK_STRUCTURE_TYPE_DEVICE_GROUP_SUBMIT_INFO_KHX = 1000060005, + VK_STRUCTURE_TYPE_DEVICE_GROUP_BIND_SPARSE_INFO_KHX = 1000060006, + VK_STRUCTURE_TYPE_DEVICE_GROUP_PRESENT_CAPABILITIES_KHX = 1000060007, + VK_STRUCTURE_TYPE_IMAGE_SWAPCHAIN_CREATE_INFO_KHX = 1000060008, + VK_STRUCTURE_TYPE_BIND_IMAGE_MEMORY_SWAPCHAIN_INFO_KHX = 1000060009, + VK_STRUCTURE_TYPE_ACQUIRE_NEXT_IMAGE_INFO_KHX = 1000060010, + VK_STRUCTURE_TYPE_DEVICE_GROUP_PRESENT_INFO_KHX = 1000060011, + VK_STRUCTURE_TYPE_DEVICE_GROUP_SWAPCHAIN_CREATE_INFO_KHX = 1000060012, VK_STRUCTURE_TYPE_VALIDATION_FLAGS_EXT = 1000061000, + VK_STRUCTURE_TYPE_VI_SURFACE_CREATE_INFO_NN = 1000062000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_GROUP_PROPERTIES_KHX = 1000070000, + VK_STRUCTURE_TYPE_DEVICE_GROUP_DEVICE_CREATE_INFO_KHX = 1000070001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_IMAGE_FORMAT_INFO_KHR = 1000071000, + VK_STRUCTURE_TYPE_EXTERNAL_IMAGE_FORMAT_PROPERTIES_KHR = 1000071001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_BUFFER_INFO_KHR = 1000071002, + VK_STRUCTURE_TYPE_EXTERNAL_BUFFER_PROPERTIES_KHR = 1000071003, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES_KHR = 1000071004, + VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_BUFFER_CREATE_INFO_KHR = 1000072000, + VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO_KHR = 1000072001, + VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR = 1000072002, + VK_STRUCTURE_TYPE_IMPORT_MEMORY_WIN32_HANDLE_INFO_KHR = 1000073000, + VK_STRUCTURE_TYPE_EXPORT_MEMORY_WIN32_HANDLE_INFO_KHR = 1000073001, + VK_STRUCTURE_TYPE_MEMORY_WIN32_HANDLE_PROPERTIES_KHR = 1000073002, + VK_STRUCTURE_TYPE_MEMORY_GET_WIN32_HANDLE_INFO_KHR = 1000073003, + VK_STRUCTURE_TYPE_IMPORT_MEMORY_FD_INFO_KHR = 1000074000, + VK_STRUCTURE_TYPE_MEMORY_FD_PROPERTIES_KHR = 1000074001, + VK_STRUCTURE_TYPE_MEMORY_GET_FD_INFO_KHR = 1000074002, + VK_STRUCTURE_TYPE_WIN32_KEYED_MUTEX_ACQUIRE_RELEASE_INFO_KHR = 1000075000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_SEMAPHORE_INFO_KHR = 1000076000, + VK_STRUCTURE_TYPE_EXTERNAL_SEMAPHORE_PROPERTIES_KHR = 1000076001, + VK_STRUCTURE_TYPE_EXPORT_SEMAPHORE_CREATE_INFO_KHR = 1000077000, + VK_STRUCTURE_TYPE_IMPORT_SEMAPHORE_WIN32_HANDLE_INFO_KHR = 1000078000, + VK_STRUCTURE_TYPE_EXPORT_SEMAPHORE_WIN32_HANDLE_INFO_KHR = 1000078001, + VK_STRUCTURE_TYPE_D3D12_FENCE_SUBMIT_INFO_KHR = 1000078002, + VK_STRUCTURE_TYPE_SEMAPHORE_GET_WIN32_HANDLE_INFO_KHR = 1000078003, + VK_STRUCTURE_TYPE_IMPORT_SEMAPHORE_FD_INFO_KHR = 1000079000, + VK_STRUCTURE_TYPE_SEMAPHORE_GET_FD_INFO_KHR = 1000079001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PUSH_DESCRIPTOR_PROPERTIES_KHR = 1000080000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES_KHR = 1000083000, + VK_STRUCTURE_TYPE_PRESENT_REGIONS_KHR = 1000084000, + VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR = 1000085000, + VK_STRUCTURE_TYPE_OBJECT_TABLE_CREATE_INFO_NVX = 1000086000, + VK_STRUCTURE_TYPE_INDIRECT_COMMANDS_LAYOUT_CREATE_INFO_NVX = 1000086001, + VK_STRUCTURE_TYPE_CMD_PROCESS_COMMANDS_INFO_NVX = 1000086002, + VK_STRUCTURE_TYPE_CMD_RESERVE_SPACE_FOR_COMMANDS_INFO_NVX = 1000086003, + VK_STRUCTURE_TYPE_DEVICE_GENERATED_COMMANDS_LIMITS_NVX = 1000086004, + VK_STRUCTURE_TYPE_DEVICE_GENERATED_COMMANDS_FEATURES_NVX = 1000086005, + VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_W_SCALING_STATE_CREATE_INFO_NV = 1000087000, + VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES_2_EXT = 1000090000, + VK_STRUCTURE_TYPE_DISPLAY_POWER_INFO_EXT = 1000091000, + VK_STRUCTURE_TYPE_DEVICE_EVENT_INFO_EXT = 1000091001, + VK_STRUCTURE_TYPE_DISPLAY_EVENT_INFO_EXT = 1000091002, + VK_STRUCTURE_TYPE_SWAPCHAIN_COUNTER_CREATE_INFO_EXT = 1000091003, + VK_STRUCTURE_TYPE_PRESENT_TIMES_INFO_GOOGLE = 1000092000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MULTIVIEW_PER_VIEW_ATTRIBUTES_PROPERTIES_NVX = 1000097000, + VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_SWIZZLE_STATE_CREATE_INFO_NV = 1000098000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DISCARD_RECTANGLE_PROPERTIES_EXT = 1000099000, + VK_STRUCTURE_TYPE_PIPELINE_DISCARD_RECTANGLE_STATE_CREATE_INFO_EXT = 1000099001, + VK_STRUCTURE_TYPE_HDR_METADATA_EXT = 1000105000, + VK_STRUCTURE_TYPE_SHARED_PRESENT_SURFACE_CAPABILITIES_KHR = 1000111000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_FENCE_INFO_KHR = 1000112000, + VK_STRUCTURE_TYPE_EXTERNAL_FENCE_PROPERTIES_KHR = 1000112001, + VK_STRUCTURE_TYPE_EXPORT_FENCE_CREATE_INFO_KHR = 1000113000, + VK_STRUCTURE_TYPE_IMPORT_FENCE_WIN32_HANDLE_INFO_KHR = 1000114000, + VK_STRUCTURE_TYPE_EXPORT_FENCE_WIN32_HANDLE_INFO_KHR = 1000114001, + VK_STRUCTURE_TYPE_FENCE_GET_WIN32_HANDLE_INFO_KHR = 1000114002, + VK_STRUCTURE_TYPE_IMPORT_FENCE_FD_INFO_KHR = 1000115000, + VK_STRUCTURE_TYPE_FENCE_GET_FD_INFO_KHR = 1000115001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SURFACE_INFO_2_KHR = 1000119000, + VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES_2_KHR = 1000119001, + VK_STRUCTURE_TYPE_SURFACE_FORMAT_2_KHR = 1000119002, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTER_FEATURES_KHR = 1000120000, + VK_STRUCTURE_TYPE_IOS_SURFACE_CREATE_INFO_MVK = 1000122000, + VK_STRUCTURE_TYPE_MACOS_SURFACE_CREATE_INFO_MVK = 1000123000, + VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR = 1000127000, + VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR = 1000127001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SAMPLER_FILTER_MINMAX_PROPERTIES_EXT = 1000130000, + VK_STRUCTURE_TYPE_SAMPLER_REDUCTION_MODE_CREATE_INFO_EXT = 1000130001, + VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR = 1000146000, + VK_STRUCTURE_TYPE_IMAGE_MEMORY_REQUIREMENTS_INFO_2_KHR = 1000146001, + VK_STRUCTURE_TYPE_IMAGE_SPARSE_MEMORY_REQUIREMENTS_INFO_2_KHR = 1000146002, + VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR = 1000146003, + VK_STRUCTURE_TYPE_SPARSE_IMAGE_MEMORY_REQUIREMENTS_2_KHR = 1000146004, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BLEND_OPERATION_ADVANCED_FEATURES_EXT = 1000148000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BLEND_OPERATION_ADVANCED_PROPERTIES_EXT = 1000148001, + VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_ADVANCED_STATE_CREATE_INFO_EXT = 1000148002, + VK_STRUCTURE_TYPE_PIPELINE_COVERAGE_TO_COLOR_STATE_CREATE_INFO_NV = 1000149000, + VK_STRUCTURE_TYPE_PIPELINE_COVERAGE_MODULATION_STATE_CREATE_INFO_NV = 1000152000, VK_STRUCTURE_TYPE_BEGIN_RANGE = VK_STRUCTURE_TYPE_APPLICATION_INFO, VK_STRUCTURE_TYPE_END_RANGE = VK_STRUCTURE_TYPE_LOADER_DEVICE_CREATE_INFO, VK_STRUCTURE_TYPE_RANGE_SIZE = (VK_STRUCTURE_TYPE_LOADER_DEVICE_CREATE_INFO - VK_STRUCTURE_TYPE_APPLICATION_INFO + 1), @@ -513,6 +620,7 @@ typedef enum VkImageLayout { VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL = 7, VK_IMAGE_LAYOUT_PREINITIALIZED = 8, VK_IMAGE_LAYOUT_PRESENT_SRC_KHR = 1000001002, + VK_IMAGE_LAYOUT_SHARED_PRESENT_KHR = 1000111000, VK_IMAGE_LAYOUT_BEGIN_RANGE = VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_END_RANGE = VK_IMAGE_LAYOUT_PREINITIALIZED, VK_IMAGE_LAYOUT_RANGE_SIZE = (VK_IMAGE_LAYOUT_PREINITIALIZED - VK_IMAGE_LAYOUT_UNDEFINED + 1), @@ -578,6 +686,7 @@ typedef enum VkPolygonMode { VK_POLYGON_MODE_FILL = 0, VK_POLYGON_MODE_LINE = 1, VK_POLYGON_MODE_POINT = 2, + VK_POLYGON_MODE_FILL_RECTANGLE_NV = 1000153000, VK_POLYGON_MODE_BEGIN_RANGE = VK_POLYGON_MODE_FILL, VK_POLYGON_MODE_END_RANGE = VK_POLYGON_MODE_POINT, VK_POLYGON_MODE_RANGE_SIZE = (VK_POLYGON_MODE_POINT - VK_POLYGON_MODE_FILL + 1), @@ -678,6 +787,52 @@ typedef enum VkBlendOp { VK_BLEND_OP_REVERSE_SUBTRACT = 2, VK_BLEND_OP_MIN = 3, VK_BLEND_OP_MAX = 4, + VK_BLEND_OP_ZERO_EXT = 1000148000, + VK_BLEND_OP_SRC_EXT = 1000148001, + VK_BLEND_OP_DST_EXT = 1000148002, + VK_BLEND_OP_SRC_OVER_EXT = 1000148003, + VK_BLEND_OP_DST_OVER_EXT = 1000148004, + VK_BLEND_OP_SRC_IN_EXT = 1000148005, + VK_BLEND_OP_DST_IN_EXT = 1000148006, + VK_BLEND_OP_SRC_OUT_EXT = 1000148007, + VK_BLEND_OP_DST_OUT_EXT = 1000148008, + VK_BLEND_OP_SRC_ATOP_EXT = 1000148009, + VK_BLEND_OP_DST_ATOP_EXT = 1000148010, + VK_BLEND_OP_XOR_EXT = 1000148011, + VK_BLEND_OP_MULTIPLY_EXT = 1000148012, + VK_BLEND_OP_SCREEN_EXT = 1000148013, + VK_BLEND_OP_OVERLAY_EXT = 1000148014, + VK_BLEND_OP_DARKEN_EXT = 1000148015, + VK_BLEND_OP_LIGHTEN_EXT = 1000148016, + VK_BLEND_OP_COLORDODGE_EXT = 1000148017, + VK_BLEND_OP_COLORBURN_EXT = 1000148018, + VK_BLEND_OP_HARDLIGHT_EXT = 1000148019, + VK_BLEND_OP_SOFTLIGHT_EXT = 1000148020, + VK_BLEND_OP_DIFFERENCE_EXT = 1000148021, + VK_BLEND_OP_EXCLUSION_EXT = 1000148022, + VK_BLEND_OP_INVERT_EXT = 1000148023, + VK_BLEND_OP_INVERT_RGB_EXT = 1000148024, + VK_BLEND_OP_LINEARDODGE_EXT = 1000148025, + VK_BLEND_OP_LINEARBURN_EXT = 1000148026, + VK_BLEND_OP_VIVIDLIGHT_EXT = 1000148027, + VK_BLEND_OP_LINEARLIGHT_EXT = 1000148028, + VK_BLEND_OP_PINLIGHT_EXT = 1000148029, + VK_BLEND_OP_HARDMIX_EXT = 1000148030, + VK_BLEND_OP_HSL_HUE_EXT = 1000148031, + VK_BLEND_OP_HSL_SATURATION_EXT = 1000148032, + VK_BLEND_OP_HSL_COLOR_EXT = 1000148033, + VK_BLEND_OP_HSL_LUMINOSITY_EXT = 1000148034, + VK_BLEND_OP_PLUS_EXT = 1000148035, + VK_BLEND_OP_PLUS_CLAMPED_EXT = 1000148036, + VK_BLEND_OP_PLUS_CLAMPED_ALPHA_EXT = 1000148037, + VK_BLEND_OP_PLUS_DARKER_EXT = 1000148038, + VK_BLEND_OP_MINUS_EXT = 1000148039, + VK_BLEND_OP_MINUS_CLAMPED_EXT = 1000148040, + VK_BLEND_OP_CONTRAST_EXT = 1000148041, + VK_BLEND_OP_INVERT_OVG_EXT = 1000148042, + VK_BLEND_OP_RED_EXT = 1000148043, + VK_BLEND_OP_GREEN_EXT = 1000148044, + VK_BLEND_OP_BLUE_EXT = 1000148045, VK_BLEND_OP_BEGIN_RANGE = VK_BLEND_OP_ADD, VK_BLEND_OP_END_RANGE = VK_BLEND_OP_MAX, VK_BLEND_OP_RANGE_SIZE = (VK_BLEND_OP_MAX - VK_BLEND_OP_ADD + 1), @@ -694,6 +849,8 @@ typedef enum VkDynamicState { VK_DYNAMIC_STATE_STENCIL_COMPARE_MASK = 6, VK_DYNAMIC_STATE_STENCIL_WRITE_MASK = 7, VK_DYNAMIC_STATE_STENCIL_REFERENCE = 8, + VK_DYNAMIC_STATE_VIEWPORT_W_SCALING_NV = 1000087000, + VK_DYNAMIC_STATE_DISCARD_RECTANGLE_EXT = 1000099000, VK_DYNAMIC_STATE_BEGIN_RANGE = VK_DYNAMIC_STATE_VIEWPORT, VK_DYNAMIC_STATE_END_RANGE = VK_DYNAMIC_STATE_STENCIL_REFERENCE, VK_DYNAMIC_STATE_RANGE_SIZE = (VK_DYNAMIC_STATE_STENCIL_REFERENCE - VK_DYNAMIC_STATE_VIEWPORT + 1), @@ -817,6 +974,47 @@ typedef enum VkSubpassContents { VK_SUBPASS_CONTENTS_MAX_ENUM = 0x7FFFFFFF } VkSubpassContents; +typedef enum VkObjectType { + VK_OBJECT_TYPE_UNKNOWN = 0, + VK_OBJECT_TYPE_INSTANCE = 1, + VK_OBJECT_TYPE_PHYSICAL_DEVICE = 2, + VK_OBJECT_TYPE_DEVICE = 3, + VK_OBJECT_TYPE_QUEUE = 4, + VK_OBJECT_TYPE_SEMAPHORE = 5, + VK_OBJECT_TYPE_COMMAND_BUFFER = 6, + VK_OBJECT_TYPE_FENCE = 7, + VK_OBJECT_TYPE_DEVICE_MEMORY = 8, + VK_OBJECT_TYPE_BUFFER = 9, + VK_OBJECT_TYPE_IMAGE = 10, + VK_OBJECT_TYPE_EVENT = 11, + VK_OBJECT_TYPE_QUERY_POOL = 12, + VK_OBJECT_TYPE_BUFFER_VIEW = 13, + VK_OBJECT_TYPE_IMAGE_VIEW = 14, + VK_OBJECT_TYPE_SHADER_MODULE = 15, + VK_OBJECT_TYPE_PIPELINE_CACHE = 16, + VK_OBJECT_TYPE_PIPELINE_LAYOUT = 17, + VK_OBJECT_TYPE_RENDER_PASS = 18, + VK_OBJECT_TYPE_PIPELINE = 19, + VK_OBJECT_TYPE_DESCRIPTOR_SET_LAYOUT = 20, + VK_OBJECT_TYPE_SAMPLER = 21, + VK_OBJECT_TYPE_DESCRIPTOR_POOL = 22, + VK_OBJECT_TYPE_DESCRIPTOR_SET = 23, + VK_OBJECT_TYPE_FRAMEBUFFER = 24, + VK_OBJECT_TYPE_COMMAND_POOL = 25, + VK_OBJECT_TYPE_SURFACE_KHR = 1000000000, + VK_OBJECT_TYPE_SWAPCHAIN_KHR = 1000001000, + VK_OBJECT_TYPE_DISPLAY_KHR = 1000002000, + VK_OBJECT_TYPE_DISPLAY_MODE_KHR = 1000002001, + VK_OBJECT_TYPE_DEBUG_REPORT_CALLBACK_EXT = 1000011000, + VK_OBJECT_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_KHR = 1000085000, + VK_OBJECT_TYPE_OBJECT_TABLE_NVX = 1000086000, + VK_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX = 1000086001, + VK_OBJECT_TYPE_BEGIN_RANGE = VK_OBJECT_TYPE_UNKNOWN, + VK_OBJECT_TYPE_END_RANGE = VK_OBJECT_TYPE_COMMAND_POOL, + VK_OBJECT_TYPE_RANGE_SIZE = (VK_OBJECT_TYPE_COMMAND_POOL - VK_OBJECT_TYPE_UNKNOWN + 1), + VK_OBJECT_TYPE_MAX_ENUM = 0x7FFFFFFF +} VkObjectType; + typedef VkFlags VkInstanceCreateFlags; typedef enum VkFormatFeatureFlagBits { @@ -834,6 +1032,9 @@ typedef enum VkFormatFeatureFlagBits { VK_FORMAT_FEATURE_BLIT_DST_BIT = 0x00000800, VK_FORMAT_FEATURE_SAMPLED_IMAGE_FILTER_LINEAR_BIT = 0x00001000, VK_FORMAT_FEATURE_SAMPLED_IMAGE_FILTER_CUBIC_BIT_IMG = 0x00002000, + VK_FORMAT_FEATURE_TRANSFER_SRC_BIT_KHR = 0x00004000, + VK_FORMAT_FEATURE_TRANSFER_DST_BIT_KHR = 0x00008000, + VK_FORMAT_FEATURE_SAMPLED_IMAGE_FILTER_MINMAX_BIT_EXT = 0x00010000, VK_FORMAT_FEATURE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkFormatFeatureFlagBits; typedef VkFlags VkFormatFeatureFlags; @@ -857,6 +1058,8 @@ typedef enum VkImageCreateFlagBits { VK_IMAGE_CREATE_SPARSE_ALIASED_BIT = 0x00000004, VK_IMAGE_CREATE_MUTABLE_FORMAT_BIT = 0x00000008, VK_IMAGE_CREATE_CUBE_COMPATIBLE_BIT = 0x00000010, + VK_IMAGE_CREATE_BIND_SFR_BIT_KHX = 0x00000040, + VK_IMAGE_CREATE_2D_ARRAY_COMPATIBLE_BIT_KHR = 0x00000020, VK_IMAGE_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkImageCreateFlagBits; typedef VkFlags VkImageCreateFlags; @@ -894,6 +1097,7 @@ typedef VkFlags VkMemoryPropertyFlags; typedef enum VkMemoryHeapFlagBits { VK_MEMORY_HEAP_DEVICE_LOCAL_BIT = 0x00000001, + VK_MEMORY_HEAP_MULTI_INSTANCE_BIT_KHX = 0x00000002, VK_MEMORY_HEAP_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkMemoryHeapFlagBits; typedef VkFlags VkMemoryHeapFlags; @@ -918,6 +1122,7 @@ typedef enum VkPipelineStageFlagBits { VK_PIPELINE_STAGE_HOST_BIT = 0x00004000, VK_PIPELINE_STAGE_ALL_GRAPHICS_BIT = 0x00008000, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT = 0x00010000, + VK_PIPELINE_STAGE_COMMAND_PROCESS_BIT_NVX = 0x00020000, VK_PIPELINE_STAGE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkPipelineStageFlagBits; typedef VkFlags VkPipelineStageFlags; @@ -1010,6 +1215,8 @@ typedef enum VkPipelineCreateFlagBits { VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT = 0x00000001, VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT = 0x00000002, VK_PIPELINE_CREATE_DERIVATIVE_BIT = 0x00000004, + VK_PIPELINE_CREATE_VIEW_INDEX_FROM_DEVICE_INDEX_BIT_KHX = 0x00000008, + VK_PIPELINE_CREATE_DISPATCH_BASE_KHX = 0x00000010, VK_PIPELINE_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkPipelineCreateFlagBits; typedef VkFlags VkPipelineCreateFlags; @@ -1056,6 +1263,11 @@ typedef VkFlags VkPipelineDynamicStateCreateFlags; typedef VkFlags VkPipelineLayoutCreateFlags; typedef VkFlags VkShaderStageFlags; typedef VkFlags VkSamplerCreateFlags; + +typedef enum VkDescriptorSetLayoutCreateFlagBits { + VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR = 0x00000001, + VK_DESCRIPTOR_SET_LAYOUT_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VkDescriptorSetLayoutCreateFlagBits; typedef VkFlags VkDescriptorSetLayoutCreateFlags; typedef enum VkDescriptorPoolCreateFlagBits { @@ -1072,6 +1284,12 @@ typedef enum VkAttachmentDescriptionFlagBits { VK_ATTACHMENT_DESCRIPTION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkAttachmentDescriptionFlagBits; typedef VkFlags VkAttachmentDescriptionFlags; + +typedef enum VkSubpassDescriptionFlagBits { + VK_SUBPASS_DESCRIPTION_PER_VIEW_ATTRIBUTES_BIT_NVX = 0x00000001, + VK_SUBPASS_DESCRIPTION_PER_VIEW_POSITION_X_ONLY_BIT_NVX = 0x00000002, + VK_SUBPASS_DESCRIPTION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VkSubpassDescriptionFlagBits; typedef VkFlags VkSubpassDescriptionFlags; typedef enum VkAccessFlagBits { @@ -1092,12 +1310,17 @@ typedef enum VkAccessFlagBits { VK_ACCESS_HOST_WRITE_BIT = 0x00004000, VK_ACCESS_MEMORY_READ_BIT = 0x00008000, VK_ACCESS_MEMORY_WRITE_BIT = 0x00010000, + VK_ACCESS_COMMAND_PROCESS_READ_BIT_NVX = 0x00020000, + VK_ACCESS_COMMAND_PROCESS_WRITE_BIT_NVX = 0x00040000, + VK_ACCESS_COLOR_ATTACHMENT_READ_NONCOHERENT_BIT_EXT = 0x00080000, VK_ACCESS_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkAccessFlagBits; typedef VkFlags VkAccessFlags; typedef enum VkDependencyFlagBits { VK_DEPENDENCY_BY_REGION_BIT = 0x00000001, + VK_DEPENDENCY_VIEW_LOCAL_BIT_KHX = 0x00000002, + VK_DEPENDENCY_DEVICE_GROUP_BIT_KHX = 0x00000004, VK_DEPENDENCY_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkDependencyFlagBits; typedef VkFlags VkDependencyFlags; @@ -1143,6 +1366,27 @@ typedef enum VkStencilFaceFlagBits { } VkStencilFaceFlagBits; typedef VkFlags VkStencilFaceFlags; +typedef struct VkApplicationInfo { + VkStructureType sType; + const void* pNext; + const char* pApplicationName; + uint32_t applicationVersion; + const char* pEngineName; + uint32_t engineVersion; + uint32_t apiVersion; +} VkApplicationInfo; + +typedef struct VkInstanceCreateInfo { + VkStructureType sType; + const void* pNext; + VkInstanceCreateFlags flags; + const VkApplicationInfo* pApplicationInfo; + uint32_t enabledLayerCount; + const char* const* ppEnabledLayerNames; + uint32_t enabledExtensionCount; + const char* const* ppEnabledExtensionNames; +} VkInstanceCreateInfo; + typedef void* (VKAPI_PTR *PFN_vkAllocationFunction)( void* pUserData, size_t size, @@ -1172,29 +1416,6 @@ typedef void (VKAPI_PTR *PFN_vkInternalFreeNotification)( VkInternalAllocationType allocationType, VkSystemAllocationScope allocationScope); -typedef void (VKAPI_PTR *PFN_vkVoidFunction)(void); - -typedef struct VkApplicationInfo { - VkStructureType sType; - const void* pNext; - const char* pApplicationName; - uint32_t applicationVersion; - const char* pEngineName; - uint32_t engineVersion; - uint32_t apiVersion; -} VkApplicationInfo; - -typedef struct VkInstanceCreateInfo { - VkStructureType sType; - const void* pNext; - VkInstanceCreateFlags flags; - const VkApplicationInfo* pApplicationInfo; - uint32_t enabledLayerCount; - const char* const* ppEnabledLayerNames; - uint32_t enabledExtensionCount; - const char* const* ppEnabledExtensionNames; -} VkInstanceCreateInfo; - typedef struct VkAllocationCallbacks { void* pUserData; PFN_vkAllocationFunction pfnAllocation; @@ -1435,6 +1656,7 @@ typedef struct VkPhysicalDeviceMemoryProperties { VkMemoryHeap memoryHeaps[VK_MAX_MEMORY_HEAPS]; } VkPhysicalDeviceMemoryProperties; +typedef void (VKAPI_PTR *PFN_vkVoidFunction)(void); typedef struct VkDeviceQueueCreateInfo { VkStructureType sType; const void* pNext; @@ -2360,7 +2582,7 @@ typedef void (VKAPI_PTR *PFN_vkCmdDraw)(VkCommandBuffer commandBuffer, uint32_t typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexed)(VkCommandBuffer commandBuffer, uint32_t indexCount, uint32_t instanceCount, uint32_t firstIndex, int32_t vertexOffset, uint32_t firstInstance); typedef void (VKAPI_PTR *PFN_vkCmdDrawIndirect)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, uint32_t drawCount, uint32_t stride); typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexedIndirect)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, uint32_t drawCount, uint32_t stride); -typedef void (VKAPI_PTR *PFN_vkCmdDispatch)(VkCommandBuffer commandBuffer, uint32_t x, uint32_t y, uint32_t z); +typedef void (VKAPI_PTR *PFN_vkCmdDispatch)(VkCommandBuffer commandBuffer, uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ); typedef void (VKAPI_PTR *PFN_vkCmdDispatchIndirect)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset); typedef void (VKAPI_PTR *PFN_vkCmdCopyBuffer)(VkCommandBuffer commandBuffer, VkBuffer srcBuffer, VkBuffer dstBuffer, uint32_t regionCount, const VkBufferCopy* pRegions); typedef void (VKAPI_PTR *PFN_vkCmdCopyImage)(VkCommandBuffer commandBuffer, VkImage srcImage, VkImageLayout srcImageLayout, VkImage dstImage, VkImageLayout dstImageLayout, uint32_t regionCount, const VkImageCopy* pRegions); @@ -2996,9 +3218,9 @@ VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndexedIndirect( VKAPI_ATTR void VKAPI_CALL vkCmdDispatch( VkCommandBuffer commandBuffer, - uint32_t x, - uint32_t y, - uint32_t z); + uint32_t groupCountX, + uint32_t groupCountY, + uint32_t groupCountZ); VKAPI_ATTR void VKAPI_CALL vkCmdDispatchIndirect( VkCommandBuffer commandBuffer, @@ -3197,6 +3419,20 @@ VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkSurfaceKHR) typedef enum VkColorSpaceKHR { VK_COLOR_SPACE_SRGB_NONLINEAR_KHR = 0, + VK_COLOR_SPACE_DISPLAY_P3_NONLINEAR_EXT = 1000104001, + VK_COLOR_SPACE_EXTENDED_SRGB_LINEAR_EXT = 1000104002, + VK_COLOR_SPACE_DCI_P3_LINEAR_EXT = 1000104003, + VK_COLOR_SPACE_DCI_P3_NONLINEAR_EXT = 1000104004, + VK_COLOR_SPACE_BT709_LINEAR_EXT = 1000104005, + VK_COLOR_SPACE_BT709_NONLINEAR_EXT = 1000104006, + VK_COLOR_SPACE_BT2020_LINEAR_EXT = 1000104007, + VK_COLOR_SPACE_HDR10_ST2084_EXT = 1000104008, + VK_COLOR_SPACE_DOLBYVISION_EXT = 1000104009, + VK_COLOR_SPACE_HDR10_HLG_EXT = 1000104010, + VK_COLOR_SPACE_ADOBERGB_LINEAR_EXT = 1000104011, + VK_COLOR_SPACE_ADOBERGB_NONLINEAR_EXT = 1000104012, + VK_COLOR_SPACE_PASS_THROUGH_EXT = 1000104013, + VK_COLOR_SPACE_EXTENDED_SRGB_NONLINEAR_EXT = 1000104014, VK_COLOR_SPACE_BEGIN_RANGE_KHR = VK_COLOR_SPACE_SRGB_NONLINEAR_KHR, VK_COLOR_SPACE_END_RANGE_KHR = VK_COLOR_SPACE_SRGB_NONLINEAR_KHR, VK_COLOR_SPACE_RANGE_SIZE_KHR = (VK_COLOR_SPACE_SRGB_NONLINEAR_KHR - VK_COLOR_SPACE_SRGB_NONLINEAR_KHR + 1), @@ -3208,6 +3444,8 @@ typedef enum VkPresentModeKHR { VK_PRESENT_MODE_MAILBOX_KHR = 1, VK_PRESENT_MODE_FIFO_KHR = 2, VK_PRESENT_MODE_FIFO_RELAXED_KHR = 3, + VK_PRESENT_MODE_SHARED_DEMAND_REFRESH_KHR = 1000111000, + VK_PRESENT_MODE_SHARED_CONTINUOUS_REFRESH_KHR = 1000111001, VK_PRESENT_MODE_BEGIN_RANGE_KHR = VK_PRESENT_MODE_IMMEDIATE_KHR, VK_PRESENT_MODE_END_RANGE_KHR = VK_PRESENT_MODE_FIFO_RELAXED_KHR, VK_PRESENT_MODE_RANGE_SIZE_KHR = (VK_PRESENT_MODE_FIFO_RELAXED_KHR - VK_PRESENT_MODE_IMMEDIATE_KHR + 1), @@ -3299,6 +3537,11 @@ VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkSwapchainKHR) #define VK_KHR_SWAPCHAIN_SPEC_VERSION 68 #define VK_KHR_SWAPCHAIN_EXTENSION_NAME "VK_KHR_swapchain" + +typedef enum VkSwapchainCreateFlagBitsKHR { + VK_SWAPCHAIN_CREATE_BIND_SFR_BIT_KHX = 0x00000001, + VK_SWAPCHAIN_CREATE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkSwapchainCreateFlagBitsKHR; typedef VkFlags VkSwapchainCreateFlagsKHR; typedef struct VkSwapchainCreateInfoKHR { @@ -3599,7 +3842,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL vkGetPhysicalDeviceXcbPresentationSupportKHR( #define VK_KHR_wayland_surface 1 #include -#define VK_KHR_WAYLAND_SURFACE_SPEC_VERSION 5 +#define VK_KHR_WAYLAND_SURFACE_SPEC_VERSION 6 #define VK_KHR_WAYLAND_SURFACE_EXTENSION_NAME "VK_KHR_wayland_surface" typedef VkFlags VkWaylandSurfaceCreateFlagsKHR; @@ -3697,7 +3940,7 @@ VKAPI_ATTR VkResult VKAPI_CALL vkCreateAndroidSurfaceKHR( #define VK_KHR_win32_surface 1 #include -#define VK_KHR_WIN32_SURFACE_SPEC_VERSION 5 +#define VK_KHR_WIN32_SURFACE_SPEC_VERSION 6 #define VK_KHR_WIN32_SURFACE_EXTENSION_NAME "VK_KHR_win32_surface" typedef VkFlags VkWin32SurfaceCreateFlagsKHR; @@ -3732,426 +3975,2480 @@ VKAPI_ATTR VkBool32 VKAPI_CALL vkGetPhysicalDeviceWin32PresentationSupportKHR( #define VK_KHR_SAMPLER_MIRROR_CLAMP_TO_EDGE_EXTENSION_NAME "VK_KHR_sampler_mirror_clamp_to_edge" -#define VK_EXT_debug_report 1 -VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkDebugReportCallbackEXT) - -#define VK_EXT_DEBUG_REPORT_SPEC_VERSION 3 -#define VK_EXT_DEBUG_REPORT_EXTENSION_NAME "VK_EXT_debug_report" -#define VK_STRUCTURE_TYPE_DEBUG_REPORT_CREATE_INFO_EXT VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT +#define VK_KHR_get_physical_device_properties2 1 +#define VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_SPEC_VERSION 1 +#define VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME "VK_KHR_get_physical_device_properties2" +typedef struct VkPhysicalDeviceFeatures2KHR { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceFeatures features; +} VkPhysicalDeviceFeatures2KHR; -typedef enum VkDebugReportObjectTypeEXT { - VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT = 0, - VK_DEBUG_REPORT_OBJECT_TYPE_INSTANCE_EXT = 1, - VK_DEBUG_REPORT_OBJECT_TYPE_PHYSICAL_DEVICE_EXT = 2, - VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_EXT = 3, - VK_DEBUG_REPORT_OBJECT_TYPE_QUEUE_EXT = 4, - VK_DEBUG_REPORT_OBJECT_TYPE_SEMAPHORE_EXT = 5, - VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT = 6, - VK_DEBUG_REPORT_OBJECT_TYPE_FENCE_EXT = 7, - VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_MEMORY_EXT = 8, - VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_EXT = 9, - VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_EXT = 10, - VK_DEBUG_REPORT_OBJECT_TYPE_EVENT_EXT = 11, - VK_DEBUG_REPORT_OBJECT_TYPE_QUERY_POOL_EXT = 12, - VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_VIEW_EXT = 13, - VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_VIEW_EXT = 14, - VK_DEBUG_REPORT_OBJECT_TYPE_SHADER_MODULE_EXT = 15, - VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_CACHE_EXT = 16, - VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_LAYOUT_EXT = 17, - VK_DEBUG_REPORT_OBJECT_TYPE_RENDER_PASS_EXT = 18, - VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_EXT = 19, - VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_LAYOUT_EXT = 20, - VK_DEBUG_REPORT_OBJECT_TYPE_SAMPLER_EXT = 21, - VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_POOL_EXT = 22, - VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_EXT = 23, - VK_DEBUG_REPORT_OBJECT_TYPE_FRAMEBUFFER_EXT = 24, - VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_POOL_EXT = 25, - VK_DEBUG_REPORT_OBJECT_TYPE_SURFACE_KHR_EXT = 26, - VK_DEBUG_REPORT_OBJECT_TYPE_SWAPCHAIN_KHR_EXT = 27, - VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT = 28, - VK_DEBUG_REPORT_OBJECT_TYPE_BEGIN_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT, - VK_DEBUG_REPORT_OBJECT_TYPE_END_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT, - VK_DEBUG_REPORT_OBJECT_TYPE_RANGE_SIZE_EXT = (VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT - VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT + 1), - VK_DEBUG_REPORT_OBJECT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF -} VkDebugReportObjectTypeEXT; +typedef struct VkPhysicalDeviceProperties2KHR { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceProperties properties; +} VkPhysicalDeviceProperties2KHR; -typedef enum VkDebugReportErrorEXT { - VK_DEBUG_REPORT_ERROR_NONE_EXT = 0, - VK_DEBUG_REPORT_ERROR_CALLBACK_REF_EXT = 1, - VK_DEBUG_REPORT_ERROR_BEGIN_RANGE_EXT = VK_DEBUG_REPORT_ERROR_NONE_EXT, - VK_DEBUG_REPORT_ERROR_END_RANGE_EXT = VK_DEBUG_REPORT_ERROR_CALLBACK_REF_EXT, - VK_DEBUG_REPORT_ERROR_RANGE_SIZE_EXT = (VK_DEBUG_REPORT_ERROR_CALLBACK_REF_EXT - VK_DEBUG_REPORT_ERROR_NONE_EXT + 1), - VK_DEBUG_REPORT_ERROR_MAX_ENUM_EXT = 0x7FFFFFFF -} VkDebugReportErrorEXT; +typedef struct VkFormatProperties2KHR { + VkStructureType sType; + void* pNext; + VkFormatProperties formatProperties; +} VkFormatProperties2KHR; +typedef struct VkImageFormatProperties2KHR { + VkStructureType sType; + void* pNext; + VkImageFormatProperties imageFormatProperties; +} VkImageFormatProperties2KHR; -typedef enum VkDebugReportFlagBitsEXT { - VK_DEBUG_REPORT_INFORMATION_BIT_EXT = 0x00000001, - VK_DEBUG_REPORT_WARNING_BIT_EXT = 0x00000002, - VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT = 0x00000004, - VK_DEBUG_REPORT_ERROR_BIT_EXT = 0x00000008, - VK_DEBUG_REPORT_DEBUG_BIT_EXT = 0x00000010, - VK_DEBUG_REPORT_FLAG_BITS_MAX_ENUM_EXT = 0x7FFFFFFF -} VkDebugReportFlagBitsEXT; -typedef VkFlags VkDebugReportFlagsEXT; +typedef struct VkPhysicalDeviceImageFormatInfo2KHR { + VkStructureType sType; + const void* pNext; + VkFormat format; + VkImageType type; + VkImageTiling tiling; + VkImageUsageFlags usage; + VkImageCreateFlags flags; +} VkPhysicalDeviceImageFormatInfo2KHR; + +typedef struct VkQueueFamilyProperties2KHR { + VkStructureType sType; + void* pNext; + VkQueueFamilyProperties queueFamilyProperties; +} VkQueueFamilyProperties2KHR; -typedef VkBool32 (VKAPI_PTR *PFN_vkDebugReportCallbackEXT)( - VkDebugReportFlagsEXT flags, - VkDebugReportObjectTypeEXT objectType, - uint64_t object, - size_t location, - int32_t messageCode, - const char* pLayerPrefix, - const char* pMessage, - void* pUserData); +typedef struct VkPhysicalDeviceMemoryProperties2KHR { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceMemoryProperties memoryProperties; +} VkPhysicalDeviceMemoryProperties2KHR; +typedef struct VkSparseImageFormatProperties2KHR { + VkStructureType sType; + void* pNext; + VkSparseImageFormatProperties properties; +} VkSparseImageFormatProperties2KHR; -typedef struct VkDebugReportCallbackCreateInfoEXT { - VkStructureType sType; - const void* pNext; - VkDebugReportFlagsEXT flags; - PFN_vkDebugReportCallbackEXT pfnCallback; - void* pUserData; -} VkDebugReportCallbackCreateInfoEXT; +typedef struct VkPhysicalDeviceSparseImageFormatInfo2KHR { + VkStructureType sType; + const void* pNext; + VkFormat format; + VkImageType type; + VkSampleCountFlagBits samples; + VkImageUsageFlags usage; + VkImageTiling tiling; +} VkPhysicalDeviceSparseImageFormatInfo2KHR; -typedef VkResult (VKAPI_PTR *PFN_vkCreateDebugReportCallbackEXT)(VkInstance instance, const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugReportCallbackEXT* pCallback); -typedef void (VKAPI_PTR *PFN_vkDestroyDebugReportCallbackEXT)(VkInstance instance, VkDebugReportCallbackEXT callback, const VkAllocationCallbacks* pAllocator); -typedef void (VKAPI_PTR *PFN_vkDebugReportMessageEXT)(VkInstance instance, VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objectType, uint64_t object, size_t location, int32_t messageCode, const char* pLayerPrefix, const char* pMessage); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceFeatures2KHR)(VkPhysicalDevice physicalDevice, VkPhysicalDeviceFeatures2KHR* pFeatures); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceProperties2KHR)(VkPhysicalDevice physicalDevice, VkPhysicalDeviceProperties2KHR* pProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceFormatProperties2KHR)(VkPhysicalDevice physicalDevice, VkFormat format, VkFormatProperties2KHR* pFormatProperties); +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceImageFormatProperties2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceImageFormatInfo2KHR* pImageFormatInfo, VkImageFormatProperties2KHR* pImageFormatProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceQueueFamilyProperties2KHR)(VkPhysicalDevice physicalDevice, uint32_t* pQueueFamilyPropertyCount, VkQueueFamilyProperties2KHR* pQueueFamilyProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceMemoryProperties2KHR)(VkPhysicalDevice physicalDevice, VkPhysicalDeviceMemoryProperties2KHR* pMemoryProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceSparseImageFormatProperties2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceSparseImageFormatInfo2KHR* pFormatInfo, uint32_t* pPropertyCount, VkSparseImageFormatProperties2KHR* pProperties); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkCreateDebugReportCallbackEXT( - VkInstance instance, - const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, - const VkAllocationCallbacks* pAllocator, - VkDebugReportCallbackEXT* pCallback); +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceFeatures2KHR( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceFeatures2KHR* pFeatures); -VKAPI_ATTR void VKAPI_CALL vkDestroyDebugReportCallbackEXT( - VkInstance instance, - VkDebugReportCallbackEXT callback, - const VkAllocationCallbacks* pAllocator); +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceProperties2KHR( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceProperties2KHR* pProperties); -VKAPI_ATTR void VKAPI_CALL vkDebugReportMessageEXT( - VkInstance instance, - VkDebugReportFlagsEXT flags, - VkDebugReportObjectTypeEXT objectType, - uint64_t object, - size_t location, - int32_t messageCode, - const char* pLayerPrefix, - const char* pMessage); +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceFormatProperties2KHR( + VkPhysicalDevice physicalDevice, + VkFormat format, + VkFormatProperties2KHR* pFormatProperties); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceImageFormatProperties2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceImageFormatInfo2KHR* pImageFormatInfo, + VkImageFormatProperties2KHR* pImageFormatProperties); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceQueueFamilyProperties2KHR( + VkPhysicalDevice physicalDevice, + uint32_t* pQueueFamilyPropertyCount, + VkQueueFamilyProperties2KHR* pQueueFamilyProperties); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceMemoryProperties2KHR( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceMemoryProperties2KHR* pMemoryProperties); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceSparseImageFormatProperties2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceSparseImageFormatInfo2KHR* pFormatInfo, + uint32_t* pPropertyCount, + VkSparseImageFormatProperties2KHR* pProperties); #endif -#define VK_NV_glsl_shader 1 -#define VK_NV_GLSL_SHADER_SPEC_VERSION 1 -#define VK_NV_GLSL_SHADER_EXTENSION_NAME "VK_NV_glsl_shader" +#define VK_KHR_shader_draw_parameters 1 +#define VK_KHR_SHADER_DRAW_PARAMETERS_SPEC_VERSION 1 +#define VK_KHR_SHADER_DRAW_PARAMETERS_EXTENSION_NAME "VK_KHR_shader_draw_parameters" -#define VK_IMG_filter_cubic 1 -#define VK_IMG_FILTER_CUBIC_SPEC_VERSION 1 -#define VK_IMG_FILTER_CUBIC_EXTENSION_NAME "VK_IMG_filter_cubic" +#define VK_KHR_maintenance1 1 +#define VK_KHR_MAINTENANCE1_SPEC_VERSION 1 +#define VK_KHR_MAINTENANCE1_EXTENSION_NAME "VK_KHR_maintenance1" +typedef VkFlags VkCommandPoolTrimFlagsKHR; -#define VK_AMD_rasterization_order 1 -#define VK_AMD_RASTERIZATION_ORDER_SPEC_VERSION 1 -#define VK_AMD_RASTERIZATION_ORDER_EXTENSION_NAME "VK_AMD_rasterization_order" +typedef void (VKAPI_PTR *PFN_vkTrimCommandPoolKHR)(VkDevice device, VkCommandPool commandPool, VkCommandPoolTrimFlagsKHR flags); +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkTrimCommandPoolKHR( + VkDevice device, + VkCommandPool commandPool, + VkCommandPoolTrimFlagsKHR flags); +#endif -typedef enum VkRasterizationOrderAMD { - VK_RASTERIZATION_ORDER_STRICT_AMD = 0, - VK_RASTERIZATION_ORDER_RELAXED_AMD = 1, - VK_RASTERIZATION_ORDER_BEGIN_RANGE_AMD = VK_RASTERIZATION_ORDER_STRICT_AMD, - VK_RASTERIZATION_ORDER_END_RANGE_AMD = VK_RASTERIZATION_ORDER_RELAXED_AMD, - VK_RASTERIZATION_ORDER_RANGE_SIZE_AMD = (VK_RASTERIZATION_ORDER_RELAXED_AMD - VK_RASTERIZATION_ORDER_STRICT_AMD + 1), - VK_RASTERIZATION_ORDER_MAX_ENUM_AMD = 0x7FFFFFFF -} VkRasterizationOrderAMD; +#define VK_KHR_external_memory_capabilities 1 +#define VK_LUID_SIZE_KHR 8 +#define VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME "VK_KHR_external_memory_capabilities" + + +typedef enum VkExternalMemoryHandleTypeFlagBitsKHR { + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR = 0x00000001, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR = 0x00000002, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_KHR = 0x00000004, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_TEXTURE_BIT_KHR = 0x00000008, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_TEXTURE_KMT_BIT_KHR = 0x00000010, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP_BIT_KHR = 0x00000020, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE_BIT_KHR = 0x00000040, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalMemoryHandleTypeFlagBitsKHR; +typedef VkFlags VkExternalMemoryHandleTypeFlagsKHR; + +typedef enum VkExternalMemoryFeatureFlagBitsKHR { + VK_EXTERNAL_MEMORY_FEATURE_DEDICATED_ONLY_BIT_KHR = 0x00000001, + VK_EXTERNAL_MEMORY_FEATURE_EXPORTABLE_BIT_KHR = 0x00000002, + VK_EXTERNAL_MEMORY_FEATURE_IMPORTABLE_BIT_KHR = 0x00000004, + VK_EXTERNAL_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalMemoryFeatureFlagBitsKHR; +typedef VkFlags VkExternalMemoryFeatureFlagsKHR; + +typedef struct VkExternalMemoryPropertiesKHR { + VkExternalMemoryFeatureFlagsKHR externalMemoryFeatures; + VkExternalMemoryHandleTypeFlagsKHR exportFromImportedHandleTypes; + VkExternalMemoryHandleTypeFlagsKHR compatibleHandleTypes; +} VkExternalMemoryPropertiesKHR; + +typedef struct VkPhysicalDeviceExternalImageFormatInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalImageFormatInfoKHR; -typedef struct VkPipelineRasterizationStateRasterizationOrderAMD { - VkStructureType sType; - const void* pNext; - VkRasterizationOrderAMD rasterizationOrder; -} VkPipelineRasterizationStateRasterizationOrderAMD; +typedef struct VkExternalImageFormatPropertiesKHR { + VkStructureType sType; + void* pNext; + VkExternalMemoryPropertiesKHR externalMemoryProperties; +} VkExternalImageFormatPropertiesKHR; + +typedef struct VkPhysicalDeviceExternalBufferInfoKHR { + VkStructureType sType; + const void* pNext; + VkBufferCreateFlags flags; + VkBufferUsageFlags usage; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalBufferInfoKHR; +typedef struct VkExternalBufferPropertiesKHR { + VkStructureType sType; + void* pNext; + VkExternalMemoryPropertiesKHR externalMemoryProperties; +} VkExternalBufferPropertiesKHR; +typedef struct VkPhysicalDeviceIDPropertiesKHR { + VkStructureType sType; + void* pNext; + uint8_t deviceUUID[VK_UUID_SIZE]; + uint8_t driverUUID[VK_UUID_SIZE]; + uint8_t deviceLUID[VK_LUID_SIZE_KHR]; + uint32_t deviceNodeMask; + VkBool32 deviceLUIDValid; +} VkPhysicalDeviceIDPropertiesKHR; -#define VK_AMD_shader_trinary_minmax 1 -#define VK_AMD_SHADER_TRINARY_MINMAX_SPEC_VERSION 1 -#define VK_AMD_SHADER_TRINARY_MINMAX_EXTENSION_NAME "VK_AMD_shader_trinary_minmax" +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalBufferPropertiesKHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceExternalBufferInfoKHR* pExternalBufferInfo, VkExternalBufferPropertiesKHR* pExternalBufferProperties); -#define VK_AMD_shader_explicit_vertex_parameter 1 -#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_SPEC_VERSION 1 -#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_EXTENSION_NAME "VK_AMD_shader_explicit_vertex_parameter" +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceExternalBufferPropertiesKHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceExternalBufferInfoKHR* pExternalBufferInfo, + VkExternalBufferPropertiesKHR* pExternalBufferProperties); +#endif +#define VK_KHR_external_memory 1 +#define VK_KHR_EXTERNAL_MEMORY_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME "VK_KHR_external_memory" +#define VK_QUEUE_FAMILY_EXTERNAL_KHR (~0U-1) -#define VK_EXT_debug_marker 1 -#define VK_EXT_DEBUG_MARKER_SPEC_VERSION 3 -#define VK_EXT_DEBUG_MARKER_EXTENSION_NAME "VK_EXT_debug_marker" +typedef struct VkExternalMemoryImageCreateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsKHR handleTypes; +} VkExternalMemoryImageCreateInfoKHR; -typedef struct VkDebugMarkerObjectNameInfoEXT { - VkStructureType sType; - const void* pNext; - VkDebugReportObjectTypeEXT objectType; - uint64_t object; - const char* pObjectName; -} VkDebugMarkerObjectNameInfoEXT; +typedef struct VkExternalMemoryBufferCreateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsKHR handleTypes; +} VkExternalMemoryBufferCreateInfoKHR; -typedef struct VkDebugMarkerObjectTagInfoEXT { +typedef struct VkExportMemoryAllocateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsKHR handleTypes; +} VkExportMemoryAllocateInfoKHR; + + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_external_memory_win32 1 +#define VK_KHR_EXTERNAL_MEMORY_WIN32_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME "VK_KHR_external_memory_win32" + +typedef struct VkImportMemoryWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; + HANDLE handle; + LPCWSTR name; +} VkImportMemoryWin32HandleInfoKHR; + +typedef struct VkExportMemoryWin32HandleInfoKHR { VkStructureType sType; const void* pNext; - VkDebugReportObjectTypeEXT objectType; - uint64_t object; - uint64_t tagName; - size_t tagSize; - const void* pTag; -} VkDebugMarkerObjectTagInfoEXT; + const SECURITY_ATTRIBUTES* pAttributes; + DWORD dwAccess; + LPCWSTR name; +} VkExportMemoryWin32HandleInfoKHR; -typedef struct VkDebugMarkerMarkerInfoEXT { +typedef struct VkMemoryWin32HandlePropertiesKHR { VkStructureType sType; - const void* pNext; - const char* pMarkerName; - float color[4]; -} VkDebugMarkerMarkerInfoEXT; + void* pNext; + uint32_t memoryTypeBits; +} VkMemoryWin32HandlePropertiesKHR; +typedef struct VkMemoryGetWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkDeviceMemory memory; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkMemoryGetWin32HandleInfoKHR; -typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectTagEXT)(VkDevice device, VkDebugMarkerObjectTagInfoEXT* pTagInfo); -typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectNameEXT)(VkDevice device, VkDebugMarkerObjectNameInfoEXT* pNameInfo); -typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerBeginEXT)(VkCommandBuffer commandBuffer, VkDebugMarkerMarkerInfoEXT* pMarkerInfo); -typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerEndEXT)(VkCommandBuffer commandBuffer); -typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerInsertEXT)(VkCommandBuffer commandBuffer, VkDebugMarkerMarkerInfoEXT* pMarkerInfo); + +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandleKHR)(VkDevice device, const VkMemoryGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandlePropertiesKHR)(VkDevice device, VkExternalMemoryHandleTypeFlagBitsKHR handleType, HANDLE handle, VkMemoryWin32HandlePropertiesKHR* pMemoryWin32HandleProperties); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectTagEXT( +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandleKHR( VkDevice device, - VkDebugMarkerObjectTagInfoEXT* pTagInfo); + const VkMemoryGetWin32HandleInfoKHR* pGetWin32HandleInfo, + HANDLE* pHandle); -VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectNameEXT( +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandlePropertiesKHR( VkDevice device, - VkDebugMarkerObjectNameInfoEXT* pNameInfo); + VkExternalMemoryHandleTypeFlagBitsKHR handleType, + HANDLE handle, + VkMemoryWin32HandlePropertiesKHR* pMemoryWin32HandleProperties); +#endif +#endif /* VK_USE_PLATFORM_WIN32_KHR */ -VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerBeginEXT( - VkCommandBuffer commandBuffer, - VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +#define VK_KHR_external_memory_fd 1 +#define VK_KHR_EXTERNAL_MEMORY_FD_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME "VK_KHR_external_memory_fd" -VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerEndEXT( - VkCommandBuffer commandBuffer); +typedef struct VkImportMemoryFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; + int fd; +} VkImportMemoryFdInfoKHR; -VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerInsertEXT( - VkCommandBuffer commandBuffer, - VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +typedef struct VkMemoryFdPropertiesKHR { + VkStructureType sType; + void* pNext; + uint32_t memoryTypeBits; +} VkMemoryFdPropertiesKHR; + +typedef struct VkMemoryGetFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkDeviceMemory memory; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkMemoryGetFdInfoKHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryFdKHR)(VkDevice device, const VkMemoryGetFdInfoKHR* pGetFdInfo, int* pFd); +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryFdPropertiesKHR)(VkDevice device, VkExternalMemoryHandleTypeFlagBitsKHR handleType, int fd, VkMemoryFdPropertiesKHR* pMemoryFdProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryFdKHR( + VkDevice device, + const VkMemoryGetFdInfoKHR* pGetFdInfo, + int* pFd); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryFdPropertiesKHR( + VkDevice device, + VkExternalMemoryHandleTypeFlagBitsKHR handleType, + int fd, + VkMemoryFdPropertiesKHR* pMemoryFdProperties); #endif -#define VK_AMD_gcn_shader 1 -#define VK_AMD_GCN_SHADER_SPEC_VERSION 1 -#define VK_AMD_GCN_SHADER_EXTENSION_NAME "VK_AMD_gcn_shader" +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_win32_keyed_mutex 1 +#define VK_KHR_WIN32_KEYED_MUTEX_SPEC_VERSION 1 +#define VK_KHR_WIN32_KEYED_MUTEX_EXTENSION_NAME "VK_KHR_win32_keyed_mutex" +typedef struct VkWin32KeyedMutexAcquireReleaseInfoKHR { + VkStructureType sType; + const void* pNext; + uint32_t acquireCount; + const VkDeviceMemory* pAcquireSyncs; + const uint64_t* pAcquireKeys; + const uint32_t* pAcquireTimeouts; + uint32_t releaseCount; + const VkDeviceMemory* pReleaseSyncs; + const uint64_t* pReleaseKeys; +} VkWin32KeyedMutexAcquireReleaseInfoKHR; -#define VK_NV_dedicated_allocation 1 -#define VK_NV_DEDICATED_ALLOCATION_SPEC_VERSION 1 -#define VK_NV_DEDICATED_ALLOCATION_EXTENSION_NAME "VK_NV_dedicated_allocation" -typedef struct VkDedicatedAllocationImageCreateInfoNV { - VkStructureType sType; - const void* pNext; - VkBool32 dedicatedAllocation; -} VkDedicatedAllocationImageCreateInfoNV; +#endif /* VK_USE_PLATFORM_WIN32_KHR */ -typedef struct VkDedicatedAllocationBufferCreateInfoNV { - VkStructureType sType; - const void* pNext; - VkBool32 dedicatedAllocation; -} VkDedicatedAllocationBufferCreateInfoNV; +#define VK_KHR_external_semaphore_capabilities 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_CAPABILITIES_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_CAPABILITIES_EXTENSION_NAME "VK_KHR_external_semaphore_capabilities" + + +typedef enum VkExternalSemaphoreHandleTypeFlagBitsKHR { + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD_BIT_KHR = 0x00000001, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR = 0x00000002, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_KHR = 0x00000004, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE_BIT_KHR = 0x00000008, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT_KHR = 0x00000010, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalSemaphoreHandleTypeFlagBitsKHR; +typedef VkFlags VkExternalSemaphoreHandleTypeFlagsKHR; + +typedef enum VkExternalSemaphoreFeatureFlagBitsKHR { + VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT_KHR = 0x00000001, + VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT_KHR = 0x00000002, + VK_EXTERNAL_SEMAPHORE_FEATURE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalSemaphoreFeatureFlagBitsKHR; +typedef VkFlags VkExternalSemaphoreFeatureFlagsKHR; + +typedef struct VkPhysicalDeviceExternalSemaphoreInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalSemaphoreInfoKHR; -typedef struct VkDedicatedAllocationMemoryAllocateInfoNV { +typedef struct VkExternalSemaphorePropertiesKHR { + VkStructureType sType; + void* pNext; + VkExternalSemaphoreHandleTypeFlagsKHR exportFromImportedHandleTypes; + VkExternalSemaphoreHandleTypeFlagsKHR compatibleHandleTypes; + VkExternalSemaphoreFeatureFlagsKHR externalSemaphoreFeatures; +} VkExternalSemaphorePropertiesKHR; + + +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalSemaphorePropertiesKHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceExternalSemaphoreInfoKHR* pExternalSemaphoreInfo, VkExternalSemaphorePropertiesKHR* pExternalSemaphoreProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceExternalSemaphorePropertiesKHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceExternalSemaphoreInfoKHR* pExternalSemaphoreInfo, + VkExternalSemaphorePropertiesKHR* pExternalSemaphoreProperties); +#endif + +#define VK_KHR_external_semaphore 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_EXTENSION_NAME "VK_KHR_external_semaphore" + + +typedef enum VkSemaphoreImportFlagBitsKHR { + VK_SEMAPHORE_IMPORT_TEMPORARY_BIT_KHR = 0x00000001, + VK_SEMAPHORE_IMPORT_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkSemaphoreImportFlagBitsKHR; +typedef VkFlags VkSemaphoreImportFlagsKHR; + +typedef struct VkExportSemaphoreCreateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalSemaphoreHandleTypeFlagsKHR handleTypes; +} VkExportSemaphoreCreateInfoKHR; + + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_external_semaphore_win32 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_WIN32_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_WIN32_EXTENSION_NAME "VK_KHR_external_semaphore_win32" + +typedef struct VkImportSemaphoreWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkSemaphoreImportFlagsKHR flags; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; + HANDLE handle; + LPCWSTR name; +} VkImportSemaphoreWin32HandleInfoKHR; + +typedef struct VkExportSemaphoreWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + const SECURITY_ATTRIBUTES* pAttributes; + DWORD dwAccess; + LPCWSTR name; +} VkExportSemaphoreWin32HandleInfoKHR; + +typedef struct VkD3D12FenceSubmitInfoKHR { VkStructureType sType; const void* pNext; - VkImage image; - VkBuffer buffer; -} VkDedicatedAllocationMemoryAllocateInfoNV; + uint32_t waitSemaphoreValuesCount; + const uint64_t* pWaitSemaphoreValues; + uint32_t signalSemaphoreValuesCount; + const uint64_t* pSignalSemaphoreValues; +} VkD3D12FenceSubmitInfoKHR; + +typedef struct VkSemaphoreGetWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; +} VkSemaphoreGetWin32HandleInfoKHR; +typedef VkResult (VKAPI_PTR *PFN_vkImportSemaphoreWin32HandleKHR)(VkDevice device, const VkImportSemaphoreWin32HandleInfoKHR* pImportSemaphoreWin32HandleInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetSemaphoreWin32HandleKHR)(VkDevice device, const VkSemaphoreGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); -#define VK_AMD_draw_indirect_count 1 -#define VK_AMD_DRAW_INDIRECT_COUNT_SPEC_VERSION 1 -#define VK_AMD_DRAW_INDIRECT_COUNT_EXTENSION_NAME "VK_AMD_draw_indirect_count" +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkImportSemaphoreWin32HandleKHR( + VkDevice device, + const VkImportSemaphoreWin32HandleInfoKHR* pImportSemaphoreWin32HandleInfo); -typedef void (VKAPI_PTR *PFN_vkCmdDrawIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); -typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexedIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); +VKAPI_ATTR VkResult VKAPI_CALL vkGetSemaphoreWin32HandleKHR( + VkDevice device, + const VkSemaphoreGetWin32HandleInfoKHR* pGetWin32HandleInfo, + HANDLE* pHandle); +#endif +#endif /* VK_USE_PLATFORM_WIN32_KHR */ + +#define VK_KHR_external_semaphore_fd 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_FD_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_FD_EXTENSION_NAME "VK_KHR_external_semaphore_fd" + +typedef struct VkImportSemaphoreFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkSemaphoreImportFlagsKHR flags; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; + int fd; +} VkImportSemaphoreFdInfoKHR; + +typedef struct VkSemaphoreGetFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; +} VkSemaphoreGetFdInfoKHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkImportSemaphoreFdKHR)(VkDevice device, const VkImportSemaphoreFdInfoKHR* pImportSemaphoreFdInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetSemaphoreFdKHR)(VkDevice device, const VkSemaphoreGetFdInfoKHR* pGetFdInfo, int* pFd); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndirectCountAMD( - VkCommandBuffer commandBuffer, - VkBuffer buffer, - VkDeviceSize offset, - VkBuffer countBuffer, - VkDeviceSize countBufferOffset, - uint32_t maxDrawCount, - uint32_t stride); +VKAPI_ATTR VkResult VKAPI_CALL vkImportSemaphoreFdKHR( + VkDevice device, + const VkImportSemaphoreFdInfoKHR* pImportSemaphoreFdInfo); -VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndexedIndirectCountAMD( +VKAPI_ATTR VkResult VKAPI_CALL vkGetSemaphoreFdKHR( + VkDevice device, + const VkSemaphoreGetFdInfoKHR* pGetFdInfo, + int* pFd); +#endif + +#define VK_KHR_push_descriptor 1 +#define VK_KHR_PUSH_DESCRIPTOR_SPEC_VERSION 1 +#define VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME "VK_KHR_push_descriptor" + +typedef struct VkPhysicalDevicePushDescriptorPropertiesKHR { + VkStructureType sType; + void* pNext; + uint32_t maxPushDescriptors; +} VkPhysicalDevicePushDescriptorPropertiesKHR; + + +typedef void (VKAPI_PTR *PFN_vkCmdPushDescriptorSetKHR)(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipelineBindPoint, VkPipelineLayout layout, uint32_t set, uint32_t descriptorWriteCount, const VkWriteDescriptorSet* pDescriptorWrites); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdPushDescriptorSetKHR( VkCommandBuffer commandBuffer, - VkBuffer buffer, - VkDeviceSize offset, - VkBuffer countBuffer, - VkDeviceSize countBufferOffset, - uint32_t maxDrawCount, - uint32_t stride); + VkPipelineBindPoint pipelineBindPoint, + VkPipelineLayout layout, + uint32_t set, + uint32_t descriptorWriteCount, + const VkWriteDescriptorSet* pDescriptorWrites); #endif -#define VK_AMD_negative_viewport_height 1 -#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_SPEC_VERSION 0 -#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_EXTENSION_NAME "VK_AMD_negative_viewport_height" +#define VK_KHR_16bit_storage 1 +#define VK_KHR_16BIT_STORAGE_SPEC_VERSION 1 +#define VK_KHR_16BIT_STORAGE_EXTENSION_NAME "VK_KHR_16bit_storage" +typedef struct VkPhysicalDevice16BitStorageFeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 storageBuffer16BitAccess; + VkBool32 uniformAndStorageBuffer16BitAccess; + VkBool32 storagePushConstant16; + VkBool32 storageInputOutput16; +} VkPhysicalDevice16BitStorageFeaturesKHR; -#define VK_AMD_gpu_shader_half_float 1 -#define VK_AMD_GPU_SHADER_HALF_FLOAT_SPEC_VERSION 1 -#define VK_AMD_GPU_SHADER_HALF_FLOAT_EXTENSION_NAME "VK_AMD_gpu_shader_half_float" -#define VK_AMD_shader_ballot 1 -#define VK_AMD_SHADER_BALLOT_SPEC_VERSION 0 -#define VK_AMD_SHADER_BALLOT_EXTENSION_NAME "VK_AMD_shader_ballot" +#define VK_KHR_incremental_present 1 +#define VK_KHR_INCREMENTAL_PRESENT_SPEC_VERSION 1 +#define VK_KHR_INCREMENTAL_PRESENT_EXTENSION_NAME "VK_KHR_incremental_present" +typedef struct VkRectLayerKHR { + VkOffset2D offset; + VkExtent2D extent; + uint32_t layer; +} VkRectLayerKHR; -#define VK_IMG_format_pvrtc 1 -#define VK_IMG_FORMAT_PVRTC_SPEC_VERSION 1 -#define VK_IMG_FORMAT_PVRTC_EXTENSION_NAME "VK_IMG_format_pvrtc" +typedef struct VkPresentRegionKHR { + uint32_t rectangleCount; + const VkRectLayerKHR* pRectangles; +} VkPresentRegionKHR; +typedef struct VkPresentRegionsKHR { + VkStructureType sType; + const void* pNext; + uint32_t swapchainCount; + const VkPresentRegionKHR* pRegions; +} VkPresentRegionsKHR; -#define VK_NV_external_memory_capabilities 1 -#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_SPEC_VERSION 1 -#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME "VK_NV_external_memory_capabilities" -typedef enum VkExternalMemoryHandleTypeFlagBitsNV { - VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT_NV = 0x00000001, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_NV = 0x00000002, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_BIT_NV = 0x00000004, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_KMT_BIT_NV = 0x00000008, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF -} VkExternalMemoryHandleTypeFlagBitsNV; -typedef VkFlags VkExternalMemoryHandleTypeFlagsNV; +#define VK_KHR_descriptor_update_template 1 +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkDescriptorUpdateTemplateKHR) -typedef enum VkExternalMemoryFeatureFlagBitsNV { - VK_EXTERNAL_MEMORY_FEATURE_DEDICATED_ONLY_BIT_NV = 0x00000001, - VK_EXTERNAL_MEMORY_FEATURE_EXPORTABLE_BIT_NV = 0x00000002, - VK_EXTERNAL_MEMORY_FEATURE_IMPORTABLE_BIT_NV = 0x00000004, - VK_EXTERNAL_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF -} VkExternalMemoryFeatureFlagBitsNV; -typedef VkFlags VkExternalMemoryFeatureFlagsNV; +#define VK_KHR_DESCRIPTOR_UPDATE_TEMPLATE_SPEC_VERSION 1 +#define VK_KHR_DESCRIPTOR_UPDATE_TEMPLATE_EXTENSION_NAME "VK_KHR_descriptor_update_template" -typedef struct VkExternalImageFormatPropertiesNV { - VkImageFormatProperties imageFormatProperties; - VkExternalMemoryFeatureFlagsNV externalMemoryFeatures; - VkExternalMemoryHandleTypeFlagsNV exportFromImportedHandleTypes; - VkExternalMemoryHandleTypeFlagsNV compatibleHandleTypes; -} VkExternalImageFormatPropertiesNV; +typedef enum VkDescriptorUpdateTemplateTypeKHR { + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_DESCRIPTOR_SET_KHR = 0, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR = 1, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_BEGIN_RANGE_KHR = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_DESCRIPTOR_SET_KHR, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_END_RANGE_KHR = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_RANGE_SIZE_KHR = (VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR - VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_DESCRIPTOR_SET_KHR + 1), + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_MAX_ENUM_KHR = 0x7FFFFFFF +} VkDescriptorUpdateTemplateTypeKHR; -typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalImageFormatPropertiesNV)(VkPhysicalDevice physicalDevice, VkFormat format, VkImageType type, VkImageTiling tiling, VkImageUsageFlags usage, VkImageCreateFlags flags, VkExternalMemoryHandleTypeFlagsNV externalHandleType, VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); +typedef VkFlags VkDescriptorUpdateTemplateCreateFlagsKHR; + +typedef struct VkDescriptorUpdateTemplateEntryKHR { + uint32_t dstBinding; + uint32_t dstArrayElement; + uint32_t descriptorCount; + VkDescriptorType descriptorType; + size_t offset; + size_t stride; +} VkDescriptorUpdateTemplateEntryKHR; + +typedef struct VkDescriptorUpdateTemplateCreateInfoKHR { + VkStructureType sType; + void* pNext; + VkDescriptorUpdateTemplateCreateFlagsKHR flags; + uint32_t descriptorUpdateEntryCount; + const VkDescriptorUpdateTemplateEntryKHR* pDescriptorUpdateEntries; + VkDescriptorUpdateTemplateTypeKHR templateType; + VkDescriptorSetLayout descriptorSetLayout; + VkPipelineBindPoint pipelineBindPoint; + VkPipelineLayout pipelineLayout; + uint32_t set; +} VkDescriptorUpdateTemplateCreateInfoKHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateDescriptorUpdateTemplateKHR)(VkDevice device, const VkDescriptorUpdateTemplateCreateInfoKHR* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDescriptorUpdateTemplateKHR* pDescriptorUpdateTemplate); +typedef void (VKAPI_PTR *PFN_vkDestroyDescriptorUpdateTemplateKHR)(VkDevice device, VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, const VkAllocationCallbacks* pAllocator); +typedef void (VKAPI_PTR *PFN_vkUpdateDescriptorSetWithTemplateKHR)(VkDevice device, VkDescriptorSet descriptorSet, VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, const void* pData); +typedef void (VKAPI_PTR *PFN_vkCmdPushDescriptorSetWithTemplateKHR)(VkCommandBuffer commandBuffer, VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, VkPipelineLayout layout, uint32_t set, const void* pData); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceExternalImageFormatPropertiesNV( - VkPhysicalDevice physicalDevice, - VkFormat format, - VkImageType type, - VkImageTiling tiling, - VkImageUsageFlags usage, - VkImageCreateFlags flags, - VkExternalMemoryHandleTypeFlagsNV externalHandleType, - VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); +VKAPI_ATTR VkResult VKAPI_CALL vkCreateDescriptorUpdateTemplateKHR( + VkDevice device, + const VkDescriptorUpdateTemplateCreateInfoKHR* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkDescriptorUpdateTemplateKHR* pDescriptorUpdateTemplate); + +VKAPI_ATTR void VKAPI_CALL vkDestroyDescriptorUpdateTemplateKHR( + VkDevice device, + VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR void VKAPI_CALL vkUpdateDescriptorSetWithTemplateKHR( + VkDevice device, + VkDescriptorSet descriptorSet, + VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, + const void* pData); + +VKAPI_ATTR void VKAPI_CALL vkCmdPushDescriptorSetWithTemplateKHR( + VkCommandBuffer commandBuffer, + VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, + VkPipelineLayout layout, + uint32_t set, + const void* pData); #endif -#define VK_NV_external_memory 1 -#define VK_NV_EXTERNAL_MEMORY_SPEC_VERSION 1 -#define VK_NV_EXTERNAL_MEMORY_EXTENSION_NAME "VK_NV_external_memory" +#define VK_KHR_shared_presentable_image 1 +#define VK_KHR_SHARED_PRESENTABLE_IMAGE_SPEC_VERSION 1 +#define VK_KHR_SHARED_PRESENTABLE_IMAGE_EXTENSION_NAME "VK_KHR_shared_presentable_image" -typedef struct VkExternalMemoryImageCreateInfoNV { - VkStructureType sType; - const void* pNext; - VkExternalMemoryHandleTypeFlagsNV handleTypes; -} VkExternalMemoryImageCreateInfoNV; +typedef struct VkSharedPresentSurfaceCapabilitiesKHR { + VkStructureType sType; + void* pNext; + VkImageUsageFlags sharedPresentSupportedUsageFlags; +} VkSharedPresentSurfaceCapabilitiesKHR; -typedef struct VkExportMemoryAllocateInfoNV { + +typedef VkResult (VKAPI_PTR *PFN_vkGetSwapchainStatusKHR)(VkDevice device, VkSwapchainKHR swapchain); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetSwapchainStatusKHR( + VkDevice device, + VkSwapchainKHR swapchain); +#endif + +#define VK_KHR_external_fence_capabilities 1 +#define VK_KHR_EXTERNAL_FENCE_CAPABILITIES_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_CAPABILITIES_EXTENSION_NAME "VK_KHR_external_fence_capabilities" + + +typedef enum VkExternalFenceHandleTypeFlagBitsKHR { + VK_EXTERNAL_FENCE_HANDLE_TYPE_OPAQUE_FD_BIT_KHR = 0x00000001, + VK_EXTERNAL_FENCE_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR = 0x00000002, + VK_EXTERNAL_FENCE_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_KHR = 0x00000004, + VK_EXTERNAL_FENCE_HANDLE_TYPE_SYNC_FD_BIT_KHR = 0x00000008, + VK_EXTERNAL_FENCE_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalFenceHandleTypeFlagBitsKHR; +typedef VkFlags VkExternalFenceHandleTypeFlagsKHR; + +typedef enum VkExternalFenceFeatureFlagBitsKHR { + VK_EXTERNAL_FENCE_FEATURE_EXPORTABLE_BIT_KHR = 0x00000001, + VK_EXTERNAL_FENCE_FEATURE_IMPORTABLE_BIT_KHR = 0x00000002, + VK_EXTERNAL_FENCE_FEATURE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalFenceFeatureFlagBitsKHR; +typedef VkFlags VkExternalFenceFeatureFlagsKHR; + +typedef struct VkPhysicalDeviceExternalFenceInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalFenceHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalFenceInfoKHR; + +typedef struct VkExternalFencePropertiesKHR { VkStructureType sType; - const void* pNext; - VkExternalMemoryHandleTypeFlagsNV handleTypes; -} VkExportMemoryAllocateInfoNV; + void* pNext; + VkExternalFenceHandleTypeFlagsKHR exportFromImportedHandleTypes; + VkExternalFenceHandleTypeFlagsKHR compatibleHandleTypes; + VkExternalFenceFeatureFlagsKHR externalFenceFeatures; +} VkExternalFencePropertiesKHR; +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalFencePropertiesKHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceExternalFenceInfoKHR* pExternalFenceInfo, VkExternalFencePropertiesKHR* pExternalFenceProperties); -#ifdef VK_USE_PLATFORM_WIN32_KHR -#define VK_NV_external_memory_win32 1 -#define VK_NV_EXTERNAL_MEMORY_WIN32_SPEC_VERSION 1 -#define VK_NV_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME "VK_NV_external_memory_win32" +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceExternalFencePropertiesKHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceExternalFenceInfoKHR* pExternalFenceInfo, + VkExternalFencePropertiesKHR* pExternalFenceProperties); +#endif -typedef struct VkImportMemoryWin32HandleInfoNV { +#define VK_KHR_external_fence 1 +#define VK_KHR_EXTERNAL_FENCE_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_EXTENSION_NAME "VK_KHR_external_fence" + + +typedef enum VkFenceImportFlagBitsKHR { + VK_FENCE_IMPORT_TEMPORARY_BIT_KHR = 0x00000001, + VK_FENCE_IMPORT_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkFenceImportFlagBitsKHR; +typedef VkFlags VkFenceImportFlagsKHR; + +typedef struct VkExportFenceCreateInfoKHR { VkStructureType sType; const void* pNext; - VkExternalMemoryHandleTypeFlagsNV handleType; - HANDLE handle; -} VkImportMemoryWin32HandleInfoNV; + VkExternalFenceHandleTypeFlagsKHR handleTypes; +} VkExportFenceCreateInfoKHR; -typedef struct VkExportMemoryWin32HandleInfoNV { + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_external_fence_win32 1 +#define VK_KHR_EXTERNAL_FENCE_WIN32_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_WIN32_EXTENSION_NAME "VK_KHR_external_fence_win32" + +typedef struct VkImportFenceWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkFenceImportFlagsKHR flags; + VkExternalFenceHandleTypeFlagBitsKHR handleType; + HANDLE handle; + LPCWSTR name; +} VkImportFenceWin32HandleInfoKHR; + +typedef struct VkExportFenceWin32HandleInfoKHR { VkStructureType sType; const void* pNext; const SECURITY_ATTRIBUTES* pAttributes; DWORD dwAccess; -} VkExportMemoryWin32HandleInfoNV; + LPCWSTR name; +} VkExportFenceWin32HandleInfoKHR; +typedef struct VkFenceGetWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkExternalFenceHandleTypeFlagBitsKHR handleType; +} VkFenceGetWin32HandleInfoKHR; -typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandleNV)(VkDevice device, VkDeviceMemory memory, VkExternalMemoryHandleTypeFlagsNV handleType, HANDLE* pHandle); + +typedef VkResult (VKAPI_PTR *PFN_vkImportFenceWin32HandleKHR)(VkDevice device, const VkImportFenceWin32HandleInfoKHR* pImportFenceWin32HandleInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetFenceWin32HandleKHR)(VkDevice device, const VkFenceGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandleNV( +VKAPI_ATTR VkResult VKAPI_CALL vkImportFenceWin32HandleKHR( VkDevice device, - VkDeviceMemory memory, - VkExternalMemoryHandleTypeFlagsNV handleType, + const VkImportFenceWin32HandleInfoKHR* pImportFenceWin32HandleInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetFenceWin32HandleKHR( + VkDevice device, + const VkFenceGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); #endif #endif /* VK_USE_PLATFORM_WIN32_KHR */ -#ifdef VK_USE_PLATFORM_WIN32_KHR -#define VK_NV_win32_keyed_mutex 1 -#define VK_NV_WIN32_KEYED_MUTEX_SPEC_VERSION 1 -#define VK_NV_WIN32_KEYED_MUTEX_EXTENSION_NAME "VK_NV_win32_keyed_mutex" +#define VK_KHR_external_fence_fd 1 +#define VK_KHR_EXTERNAL_FENCE_FD_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_FD_EXTENSION_NAME "VK_KHR_external_fence_fd" -typedef struct VkWin32KeyedMutexAcquireReleaseInfoNV { - VkStructureType sType; - const void* pNext; - uint32_t acquireCount; - const VkDeviceMemory* pAcquireSyncs; - const uint64_t* pAcquireKeys; - const uint32_t* pAcquireTimeoutMilliseconds; - uint32_t releaseCount; - const VkDeviceMemory* pReleaseSyncs; - const uint64_t* pReleaseKeys; -} VkWin32KeyedMutexAcquireReleaseInfoNV; +typedef struct VkImportFenceFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkFenceImportFlagsKHR flags; + VkExternalFenceHandleTypeFlagBitsKHR handleType; + int fd; +} VkImportFenceFdInfoKHR; +typedef struct VkFenceGetFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkExternalFenceHandleTypeFlagBitsKHR handleType; +} VkFenceGetFdInfoKHR; -#endif /* VK_USE_PLATFORM_WIN32_KHR */ -#define VK_EXT_validation_flags 1 -#define VK_EXT_VALIDATION_FLAGS_SPEC_VERSION 1 -#define VK_EXT_VALIDATION_FLAGS_EXTENSION_NAME "VK_EXT_validation_flags" +typedef VkResult (VKAPI_PTR *PFN_vkImportFenceFdKHR)(VkDevice device, const VkImportFenceFdInfoKHR* pImportFenceFdInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetFenceFdKHR)(VkDevice device, const VkFenceGetFdInfoKHR* pGetFdInfo, int* pFd); +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkImportFenceFdKHR( + VkDevice device, + const VkImportFenceFdInfoKHR* pImportFenceFdInfo); -typedef enum VkValidationCheckEXT { - VK_VALIDATION_CHECK_ALL_EXT = 0, - VK_VALIDATION_CHECK_BEGIN_RANGE_EXT = VK_VALIDATION_CHECK_ALL_EXT, - VK_VALIDATION_CHECK_END_RANGE_EXT = VK_VALIDATION_CHECK_ALL_EXT, - VK_VALIDATION_CHECK_RANGE_SIZE_EXT = (VK_VALIDATION_CHECK_ALL_EXT - VK_VALIDATION_CHECK_ALL_EXT + 1), - VK_VALIDATION_CHECK_MAX_ENUM_EXT = 0x7FFFFFFF -} VkValidationCheckEXT; +VKAPI_ATTR VkResult VKAPI_CALL vkGetFenceFdKHR( + VkDevice device, + const VkFenceGetFdInfoKHR* pGetFdInfo, + int* pFd); +#endif + +#define VK_KHR_get_surface_capabilities2 1 +#define VK_KHR_GET_SURFACE_CAPABILITIES_2_SPEC_VERSION 1 +#define VK_KHR_GET_SURFACE_CAPABILITIES_2_EXTENSION_NAME "VK_KHR_get_surface_capabilities2" + +typedef struct VkPhysicalDeviceSurfaceInfo2KHR { + VkStructureType sType; + const void* pNext; + VkSurfaceKHR surface; +} VkPhysicalDeviceSurfaceInfo2KHR; + +typedef struct VkSurfaceCapabilities2KHR { + VkStructureType sType; + void* pNext; + VkSurfaceCapabilitiesKHR surfaceCapabilities; +} VkSurfaceCapabilities2KHR; + +typedef struct VkSurfaceFormat2KHR { + VkStructureType sType; + void* pNext; + VkSurfaceFormatKHR surfaceFormat; +} VkSurfaceFormat2KHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceSurfaceCapabilities2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, VkSurfaceCapabilities2KHR* pSurfaceCapabilities); +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceSurfaceFormats2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, uint32_t* pSurfaceFormatCount, VkSurfaceFormat2KHR* pSurfaceFormats); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceSurfaceCapabilities2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, + VkSurfaceCapabilities2KHR* pSurfaceCapabilities); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceSurfaceFormats2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, + uint32_t* pSurfaceFormatCount, + VkSurfaceFormat2KHR* pSurfaceFormats); +#endif + +#define VK_KHR_variable_pointers 1 +#define VK_KHR_VARIABLE_POINTERS_SPEC_VERSION 1 +#define VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME "VK_KHR_variable_pointers" + +typedef struct VkPhysicalDeviceVariablePointerFeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 variablePointersStorageBuffer; + VkBool32 variablePointers; +} VkPhysicalDeviceVariablePointerFeaturesKHR; + + + +#define VK_KHR_dedicated_allocation 1 +#define VK_KHR_DEDICATED_ALLOCATION_SPEC_VERSION 3 +#define VK_KHR_DEDICATED_ALLOCATION_EXTENSION_NAME "VK_KHR_dedicated_allocation" + +typedef struct VkMemoryDedicatedRequirementsKHR { + VkStructureType sType; + void* pNext; + VkBool32 prefersDedicatedAllocation; + VkBool32 requiresDedicatedAllocation; +} VkMemoryDedicatedRequirementsKHR; + +typedef struct VkMemoryDedicatedAllocateInfoKHR { + VkStructureType sType; + const void* pNext; + VkImage image; + VkBuffer buffer; +} VkMemoryDedicatedAllocateInfoKHR; + + + +#define VK_KHR_storage_buffer_storage_class 1 +#define VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_SPEC_VERSION 1 +#define VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME "VK_KHR_storage_buffer_storage_class" + + +#define VK_KHR_relaxed_block_layout 1 +#define VK_KHR_RELAXED_BLOCK_LAYOUT_SPEC_VERSION 1 +#define VK_KHR_RELAXED_BLOCK_LAYOUT_EXTENSION_NAME "VK_KHR_relaxed_block_layout" + + +#define VK_KHR_get_memory_requirements2 1 +#define VK_KHR_GET_MEMORY_REQUIREMENTS_2_SPEC_VERSION 1 +#define VK_KHR_GET_MEMORY_REQUIREMENTS_2_EXTENSION_NAME "VK_KHR_get_memory_requirements2" + +typedef struct VkBufferMemoryRequirementsInfo2KHR { + VkStructureType sType; + const void* pNext; + VkBuffer buffer; +} VkBufferMemoryRequirementsInfo2KHR; + +typedef struct VkImageMemoryRequirementsInfo2KHR { + VkStructureType sType; + const void* pNext; + VkImage image; +} VkImageMemoryRequirementsInfo2KHR; + +typedef struct VkImageSparseMemoryRequirementsInfo2KHR { + VkStructureType sType; + const void* pNext; + VkImage image; +} VkImageSparseMemoryRequirementsInfo2KHR; + +typedef struct VkMemoryRequirements2KHR { + VkStructureType sType; + void* pNext; + VkMemoryRequirements memoryRequirements; +} VkMemoryRequirements2KHR; + +typedef struct VkSparseImageMemoryRequirements2KHR { + VkStructureType sType; + void* pNext; + VkSparseImageMemoryRequirements memoryRequirements; +} VkSparseImageMemoryRequirements2KHR; + + +typedef void (VKAPI_PTR *PFN_vkGetImageMemoryRequirements2KHR)(VkDevice device, const VkImageMemoryRequirementsInfo2KHR* pInfo, VkMemoryRequirements2KHR* pMemoryRequirements); +typedef void (VKAPI_PTR *PFN_vkGetBufferMemoryRequirements2KHR)(VkDevice device, const VkBufferMemoryRequirementsInfo2KHR* pInfo, VkMemoryRequirements2KHR* pMemoryRequirements); +typedef void (VKAPI_PTR *PFN_vkGetImageSparseMemoryRequirements2KHR)(VkDevice device, const VkImageSparseMemoryRequirementsInfo2KHR* pInfo, uint32_t* pSparseMemoryRequirementCount, VkSparseImageMemoryRequirements2KHR* pSparseMemoryRequirements); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetImageMemoryRequirements2KHR( + VkDevice device, + const VkImageMemoryRequirementsInfo2KHR* pInfo, + VkMemoryRequirements2KHR* pMemoryRequirements); + +VKAPI_ATTR void VKAPI_CALL vkGetBufferMemoryRequirements2KHR( + VkDevice device, + const VkBufferMemoryRequirementsInfo2KHR* pInfo, + VkMemoryRequirements2KHR* pMemoryRequirements); + +VKAPI_ATTR void VKAPI_CALL vkGetImageSparseMemoryRequirements2KHR( + VkDevice device, + const VkImageSparseMemoryRequirementsInfo2KHR* pInfo, + uint32_t* pSparseMemoryRequirementCount, + VkSparseImageMemoryRequirements2KHR* pSparseMemoryRequirements); +#endif + +#define VK_EXT_debug_report 1 +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkDebugReportCallbackEXT) + +#define VK_EXT_DEBUG_REPORT_SPEC_VERSION 8 +#define VK_EXT_DEBUG_REPORT_EXTENSION_NAME "VK_EXT_debug_report" +#define VK_STRUCTURE_TYPE_DEBUG_REPORT_CREATE_INFO_EXT VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT +#define VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_CALLBACK_EXT_EXT + + +typedef enum VkDebugReportObjectTypeEXT { + VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT = 0, + VK_DEBUG_REPORT_OBJECT_TYPE_INSTANCE_EXT = 1, + VK_DEBUG_REPORT_OBJECT_TYPE_PHYSICAL_DEVICE_EXT = 2, + VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_EXT = 3, + VK_DEBUG_REPORT_OBJECT_TYPE_QUEUE_EXT = 4, + VK_DEBUG_REPORT_OBJECT_TYPE_SEMAPHORE_EXT = 5, + VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT = 6, + VK_DEBUG_REPORT_OBJECT_TYPE_FENCE_EXT = 7, + VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_MEMORY_EXT = 8, + VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_EXT = 9, + VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_EXT = 10, + VK_DEBUG_REPORT_OBJECT_TYPE_EVENT_EXT = 11, + VK_DEBUG_REPORT_OBJECT_TYPE_QUERY_POOL_EXT = 12, + VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_VIEW_EXT = 13, + VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_VIEW_EXT = 14, + VK_DEBUG_REPORT_OBJECT_TYPE_SHADER_MODULE_EXT = 15, + VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_CACHE_EXT = 16, + VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_LAYOUT_EXT = 17, + VK_DEBUG_REPORT_OBJECT_TYPE_RENDER_PASS_EXT = 18, + VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_EXT = 19, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_LAYOUT_EXT = 20, + VK_DEBUG_REPORT_OBJECT_TYPE_SAMPLER_EXT = 21, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_POOL_EXT = 22, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_EXT = 23, + VK_DEBUG_REPORT_OBJECT_TYPE_FRAMEBUFFER_EXT = 24, + VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_POOL_EXT = 25, + VK_DEBUG_REPORT_OBJECT_TYPE_SURFACE_KHR_EXT = 26, + VK_DEBUG_REPORT_OBJECT_TYPE_SWAPCHAIN_KHR_EXT = 27, + VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_CALLBACK_EXT_EXT = 28, + VK_DEBUG_REPORT_OBJECT_TYPE_DISPLAY_KHR_EXT = 29, + VK_DEBUG_REPORT_OBJECT_TYPE_DISPLAY_MODE_KHR_EXT = 30, + VK_DEBUG_REPORT_OBJECT_TYPE_OBJECT_TABLE_NVX_EXT = 31, + VK_DEBUG_REPORT_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX_EXT = 32, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_KHR_EXT = 1000085000, + VK_DEBUG_REPORT_OBJECT_TYPE_BEGIN_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT, + VK_DEBUG_REPORT_OBJECT_TYPE_END_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX_EXT, + VK_DEBUG_REPORT_OBJECT_TYPE_RANGE_SIZE_EXT = (VK_DEBUG_REPORT_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX_EXT - VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT + 1), + VK_DEBUG_REPORT_OBJECT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDebugReportObjectTypeEXT; + + +typedef enum VkDebugReportFlagBitsEXT { + VK_DEBUG_REPORT_INFORMATION_BIT_EXT = 0x00000001, + VK_DEBUG_REPORT_WARNING_BIT_EXT = 0x00000002, + VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT = 0x00000004, + VK_DEBUG_REPORT_ERROR_BIT_EXT = 0x00000008, + VK_DEBUG_REPORT_DEBUG_BIT_EXT = 0x00000010, + VK_DEBUG_REPORT_FLAG_BITS_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDebugReportFlagBitsEXT; +typedef VkFlags VkDebugReportFlagsEXT; + +typedef VkBool32 (VKAPI_PTR *PFN_vkDebugReportCallbackEXT)( + VkDebugReportFlagsEXT flags, + VkDebugReportObjectTypeEXT objectType, + uint64_t object, + size_t location, + int32_t messageCode, + const char* pLayerPrefix, + const char* pMessage, + void* pUserData); + +typedef struct VkDebugReportCallbackCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkDebugReportFlagsEXT flags; + PFN_vkDebugReportCallbackEXT pfnCallback; + void* pUserData; +} VkDebugReportCallbackCreateInfoEXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateDebugReportCallbackEXT)(VkInstance instance, const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugReportCallbackEXT* pCallback); +typedef void (VKAPI_PTR *PFN_vkDestroyDebugReportCallbackEXT)(VkInstance instance, VkDebugReportCallbackEXT callback, const VkAllocationCallbacks* pAllocator); +typedef void (VKAPI_PTR *PFN_vkDebugReportMessageEXT)(VkInstance instance, VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objectType, uint64_t object, size_t location, int32_t messageCode, const char* pLayerPrefix, const char* pMessage); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateDebugReportCallbackEXT( + VkInstance instance, + const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkDebugReportCallbackEXT* pCallback); + +VKAPI_ATTR void VKAPI_CALL vkDestroyDebugReportCallbackEXT( + VkInstance instance, + VkDebugReportCallbackEXT callback, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR void VKAPI_CALL vkDebugReportMessageEXT( + VkInstance instance, + VkDebugReportFlagsEXT flags, + VkDebugReportObjectTypeEXT objectType, + uint64_t object, + size_t location, + int32_t messageCode, + const char* pLayerPrefix, + const char* pMessage); +#endif + +#define VK_NV_glsl_shader 1 +#define VK_NV_GLSL_SHADER_SPEC_VERSION 1 +#define VK_NV_GLSL_SHADER_EXTENSION_NAME "VK_NV_glsl_shader" + + +#define VK_EXT_depth_range_unrestricted 1 +#define VK_EXT_DEPTH_RANGE_UNRESTRICTED_SPEC_VERSION 1 +#define VK_EXT_DEPTH_RANGE_UNRESTRICTED_EXTENSION_NAME "VK_EXT_depth_range_unrestricted" + + +#define VK_IMG_filter_cubic 1 +#define VK_IMG_FILTER_CUBIC_SPEC_VERSION 1 +#define VK_IMG_FILTER_CUBIC_EXTENSION_NAME "VK_IMG_filter_cubic" + + +#define VK_AMD_rasterization_order 1 +#define VK_AMD_RASTERIZATION_ORDER_SPEC_VERSION 1 +#define VK_AMD_RASTERIZATION_ORDER_EXTENSION_NAME "VK_AMD_rasterization_order" + + +typedef enum VkRasterizationOrderAMD { + VK_RASTERIZATION_ORDER_STRICT_AMD = 0, + VK_RASTERIZATION_ORDER_RELAXED_AMD = 1, + VK_RASTERIZATION_ORDER_BEGIN_RANGE_AMD = VK_RASTERIZATION_ORDER_STRICT_AMD, + VK_RASTERIZATION_ORDER_END_RANGE_AMD = VK_RASTERIZATION_ORDER_RELAXED_AMD, + VK_RASTERIZATION_ORDER_RANGE_SIZE_AMD = (VK_RASTERIZATION_ORDER_RELAXED_AMD - VK_RASTERIZATION_ORDER_STRICT_AMD + 1), + VK_RASTERIZATION_ORDER_MAX_ENUM_AMD = 0x7FFFFFFF +} VkRasterizationOrderAMD; + +typedef struct VkPipelineRasterizationStateRasterizationOrderAMD { + VkStructureType sType; + const void* pNext; + VkRasterizationOrderAMD rasterizationOrder; +} VkPipelineRasterizationStateRasterizationOrderAMD; + + + +#define VK_AMD_shader_trinary_minmax 1 +#define VK_AMD_SHADER_TRINARY_MINMAX_SPEC_VERSION 1 +#define VK_AMD_SHADER_TRINARY_MINMAX_EXTENSION_NAME "VK_AMD_shader_trinary_minmax" + + +#define VK_AMD_shader_explicit_vertex_parameter 1 +#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_SPEC_VERSION 1 +#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_EXTENSION_NAME "VK_AMD_shader_explicit_vertex_parameter" + + +#define VK_EXT_debug_marker 1 +#define VK_EXT_DEBUG_MARKER_SPEC_VERSION 4 +#define VK_EXT_DEBUG_MARKER_EXTENSION_NAME "VK_EXT_debug_marker" + +typedef struct VkDebugMarkerObjectNameInfoEXT { + VkStructureType sType; + const void* pNext; + VkDebugReportObjectTypeEXT objectType; + uint64_t object; + const char* pObjectName; +} VkDebugMarkerObjectNameInfoEXT; + +typedef struct VkDebugMarkerObjectTagInfoEXT { + VkStructureType sType; + const void* pNext; + VkDebugReportObjectTypeEXT objectType; + uint64_t object; + uint64_t tagName; + size_t tagSize; + const void* pTag; +} VkDebugMarkerObjectTagInfoEXT; + +typedef struct VkDebugMarkerMarkerInfoEXT { + VkStructureType sType; + const void* pNext; + const char* pMarkerName; + float color[4]; +} VkDebugMarkerMarkerInfoEXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectTagEXT)(VkDevice device, const VkDebugMarkerObjectTagInfoEXT* pTagInfo); +typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectNameEXT)(VkDevice device, const VkDebugMarkerObjectNameInfoEXT* pNameInfo); +typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerBeginEXT)(VkCommandBuffer commandBuffer, const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerEndEXT)(VkCommandBuffer commandBuffer); +typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerInsertEXT)(VkCommandBuffer commandBuffer, const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectTagEXT( + VkDevice device, + const VkDebugMarkerObjectTagInfoEXT* pTagInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectNameEXT( + VkDevice device, + const VkDebugMarkerObjectNameInfoEXT* pNameInfo); + +VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerBeginEXT( + VkCommandBuffer commandBuffer, + const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); + +VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerEndEXT( + VkCommandBuffer commandBuffer); + +VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerInsertEXT( + VkCommandBuffer commandBuffer, + const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +#endif + +#define VK_AMD_gcn_shader 1 +#define VK_AMD_GCN_SHADER_SPEC_VERSION 1 +#define VK_AMD_GCN_SHADER_EXTENSION_NAME "VK_AMD_gcn_shader" + + +#define VK_NV_dedicated_allocation 1 +#define VK_NV_DEDICATED_ALLOCATION_SPEC_VERSION 1 +#define VK_NV_DEDICATED_ALLOCATION_EXTENSION_NAME "VK_NV_dedicated_allocation" + +typedef struct VkDedicatedAllocationImageCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkBool32 dedicatedAllocation; +} VkDedicatedAllocationImageCreateInfoNV; + +typedef struct VkDedicatedAllocationBufferCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkBool32 dedicatedAllocation; +} VkDedicatedAllocationBufferCreateInfoNV; + +typedef struct VkDedicatedAllocationMemoryAllocateInfoNV { + VkStructureType sType; + const void* pNext; + VkImage image; + VkBuffer buffer; +} VkDedicatedAllocationMemoryAllocateInfoNV; + + + +#define VK_AMD_draw_indirect_count 1 +#define VK_AMD_DRAW_INDIRECT_COUNT_SPEC_VERSION 1 +#define VK_AMD_DRAW_INDIRECT_COUNT_EXTENSION_NAME "VK_AMD_draw_indirect_count" + +typedef void (VKAPI_PTR *PFN_vkCmdDrawIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); +typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexedIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndirectCountAMD( + VkCommandBuffer commandBuffer, + VkBuffer buffer, + VkDeviceSize offset, + VkBuffer countBuffer, + VkDeviceSize countBufferOffset, + uint32_t maxDrawCount, + uint32_t stride); + +VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndexedIndirectCountAMD( + VkCommandBuffer commandBuffer, + VkBuffer buffer, + VkDeviceSize offset, + VkBuffer countBuffer, + VkDeviceSize countBufferOffset, + uint32_t maxDrawCount, + uint32_t stride); +#endif + +#define VK_AMD_negative_viewport_height 1 +#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_SPEC_VERSION 1 +#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_EXTENSION_NAME "VK_AMD_negative_viewport_height" + + +#define VK_AMD_gpu_shader_half_float 1 +#define VK_AMD_GPU_SHADER_HALF_FLOAT_SPEC_VERSION 1 +#define VK_AMD_GPU_SHADER_HALF_FLOAT_EXTENSION_NAME "VK_AMD_gpu_shader_half_float" + + +#define VK_AMD_shader_ballot 1 +#define VK_AMD_SHADER_BALLOT_SPEC_VERSION 1 +#define VK_AMD_SHADER_BALLOT_EXTENSION_NAME "VK_AMD_shader_ballot" + + +#define VK_AMD_texture_gather_bias_lod 1 +#define VK_AMD_TEXTURE_GATHER_BIAS_LOD_SPEC_VERSION 1 +#define VK_AMD_TEXTURE_GATHER_BIAS_LOD_EXTENSION_NAME "VK_AMD_texture_gather_bias_lod" + +typedef struct VkTextureLODGatherFormatPropertiesAMD { + VkStructureType sType; + void* pNext; + VkBool32 supportsTextureGatherLODBiasAMD; +} VkTextureLODGatherFormatPropertiesAMD; + + + +#define VK_KHX_multiview 1 +#define VK_KHX_MULTIVIEW_SPEC_VERSION 1 +#define VK_KHX_MULTIVIEW_EXTENSION_NAME "VK_KHX_multiview" + +typedef struct VkRenderPassMultiviewCreateInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t subpassCount; + const uint32_t* pViewMasks; + uint32_t dependencyCount; + const int32_t* pViewOffsets; + uint32_t correlationMaskCount; + const uint32_t* pCorrelationMasks; +} VkRenderPassMultiviewCreateInfoKHX; + +typedef struct VkPhysicalDeviceMultiviewFeaturesKHX { + VkStructureType sType; + void* pNext; + VkBool32 multiview; + VkBool32 multiviewGeometryShader; + VkBool32 multiviewTessellationShader; +} VkPhysicalDeviceMultiviewFeaturesKHX; + +typedef struct VkPhysicalDeviceMultiviewPropertiesKHX { + VkStructureType sType; + void* pNext; + uint32_t maxMultiviewViewCount; + uint32_t maxMultiviewInstanceIndex; +} VkPhysicalDeviceMultiviewPropertiesKHX; + + + +#define VK_IMG_format_pvrtc 1 +#define VK_IMG_FORMAT_PVRTC_SPEC_VERSION 1 +#define VK_IMG_FORMAT_PVRTC_EXTENSION_NAME "VK_IMG_format_pvrtc" + + +#define VK_NV_external_memory_capabilities 1 +#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_SPEC_VERSION 1 +#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME "VK_NV_external_memory_capabilities" + + +typedef enum VkExternalMemoryHandleTypeFlagBitsNV { + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT_NV = 0x00000001, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_NV = 0x00000002, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_BIT_NV = 0x00000004, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_KMT_BIT_NV = 0x00000008, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF +} VkExternalMemoryHandleTypeFlagBitsNV; +typedef VkFlags VkExternalMemoryHandleTypeFlagsNV; + +typedef enum VkExternalMemoryFeatureFlagBitsNV { + VK_EXTERNAL_MEMORY_FEATURE_DEDICATED_ONLY_BIT_NV = 0x00000001, + VK_EXTERNAL_MEMORY_FEATURE_EXPORTABLE_BIT_NV = 0x00000002, + VK_EXTERNAL_MEMORY_FEATURE_IMPORTABLE_BIT_NV = 0x00000004, + VK_EXTERNAL_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF +} VkExternalMemoryFeatureFlagBitsNV; +typedef VkFlags VkExternalMemoryFeatureFlagsNV; + +typedef struct VkExternalImageFormatPropertiesNV { + VkImageFormatProperties imageFormatProperties; + VkExternalMemoryFeatureFlagsNV externalMemoryFeatures; + VkExternalMemoryHandleTypeFlagsNV exportFromImportedHandleTypes; + VkExternalMemoryHandleTypeFlagsNV compatibleHandleTypes; +} VkExternalImageFormatPropertiesNV; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalImageFormatPropertiesNV)(VkPhysicalDevice physicalDevice, VkFormat format, VkImageType type, VkImageTiling tiling, VkImageUsageFlags usage, VkImageCreateFlags flags, VkExternalMemoryHandleTypeFlagsNV externalHandleType, VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceExternalImageFormatPropertiesNV( + VkPhysicalDevice physicalDevice, + VkFormat format, + VkImageType type, + VkImageTiling tiling, + VkImageUsageFlags usage, + VkImageCreateFlags flags, + VkExternalMemoryHandleTypeFlagsNV externalHandleType, + VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); +#endif + +#define VK_NV_external_memory 1 +#define VK_NV_EXTERNAL_MEMORY_SPEC_VERSION 1 +#define VK_NV_EXTERNAL_MEMORY_EXTENSION_NAME "VK_NV_external_memory" + +typedef struct VkExternalMemoryImageCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsNV handleTypes; +} VkExternalMemoryImageCreateInfoNV; + +typedef struct VkExportMemoryAllocateInfoNV { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsNV handleTypes; +} VkExportMemoryAllocateInfoNV; + + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_NV_external_memory_win32 1 +#define VK_NV_EXTERNAL_MEMORY_WIN32_SPEC_VERSION 1 +#define VK_NV_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME "VK_NV_external_memory_win32" + +typedef struct VkImportMemoryWin32HandleInfoNV { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsNV handleType; + HANDLE handle; +} VkImportMemoryWin32HandleInfoNV; + +typedef struct VkExportMemoryWin32HandleInfoNV { + VkStructureType sType; + const void* pNext; + const SECURITY_ATTRIBUTES* pAttributes; + DWORD dwAccess; +} VkExportMemoryWin32HandleInfoNV; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandleNV)(VkDevice device, VkDeviceMemory memory, VkExternalMemoryHandleTypeFlagsNV handleType, HANDLE* pHandle); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandleNV( + VkDevice device, + VkDeviceMemory memory, + VkExternalMemoryHandleTypeFlagsNV handleType, + HANDLE* pHandle); +#endif +#endif /* VK_USE_PLATFORM_WIN32_KHR */ + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_NV_win32_keyed_mutex 1 +#define VK_NV_WIN32_KEYED_MUTEX_SPEC_VERSION 1 +#define VK_NV_WIN32_KEYED_MUTEX_EXTENSION_NAME "VK_NV_win32_keyed_mutex" + +typedef struct VkWin32KeyedMutexAcquireReleaseInfoNV { + VkStructureType sType; + const void* pNext; + uint32_t acquireCount; + const VkDeviceMemory* pAcquireSyncs; + const uint64_t* pAcquireKeys; + const uint32_t* pAcquireTimeoutMilliseconds; + uint32_t releaseCount; + const VkDeviceMemory* pReleaseSyncs; + const uint64_t* pReleaseKeys; +} VkWin32KeyedMutexAcquireReleaseInfoNV; + + +#endif /* VK_USE_PLATFORM_WIN32_KHR */ + +#define VK_KHX_device_group 1 +#define VK_MAX_DEVICE_GROUP_SIZE_KHX 32 +#define VK_KHX_DEVICE_GROUP_SPEC_VERSION 1 +#define VK_KHX_DEVICE_GROUP_EXTENSION_NAME "VK_KHX_device_group" + + +typedef enum VkPeerMemoryFeatureFlagBitsKHX { + VK_PEER_MEMORY_FEATURE_COPY_SRC_BIT_KHX = 0x00000001, + VK_PEER_MEMORY_FEATURE_COPY_DST_BIT_KHX = 0x00000002, + VK_PEER_MEMORY_FEATURE_GENERIC_SRC_BIT_KHX = 0x00000004, + VK_PEER_MEMORY_FEATURE_GENERIC_DST_BIT_KHX = 0x00000008, + VK_PEER_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_KHX = 0x7FFFFFFF +} VkPeerMemoryFeatureFlagBitsKHX; +typedef VkFlags VkPeerMemoryFeatureFlagsKHX; + +typedef enum VkMemoryAllocateFlagBitsKHX { + VK_MEMORY_ALLOCATE_DEVICE_MASK_BIT_KHX = 0x00000001, + VK_MEMORY_ALLOCATE_FLAG_BITS_MAX_ENUM_KHX = 0x7FFFFFFF +} VkMemoryAllocateFlagBitsKHX; +typedef VkFlags VkMemoryAllocateFlagsKHX; + +typedef enum VkDeviceGroupPresentModeFlagBitsKHX { + VK_DEVICE_GROUP_PRESENT_MODE_LOCAL_BIT_KHX = 0x00000001, + VK_DEVICE_GROUP_PRESENT_MODE_REMOTE_BIT_KHX = 0x00000002, + VK_DEVICE_GROUP_PRESENT_MODE_SUM_BIT_KHX = 0x00000004, + VK_DEVICE_GROUP_PRESENT_MODE_LOCAL_MULTI_DEVICE_BIT_KHX = 0x00000008, + VK_DEVICE_GROUP_PRESENT_MODE_FLAG_BITS_MAX_ENUM_KHX = 0x7FFFFFFF +} VkDeviceGroupPresentModeFlagBitsKHX; +typedef VkFlags VkDeviceGroupPresentModeFlagsKHX; + +typedef struct VkMemoryAllocateFlagsInfoKHX { + VkStructureType sType; + const void* pNext; + VkMemoryAllocateFlagsKHX flags; + uint32_t deviceMask; +} VkMemoryAllocateFlagsInfoKHX; + +typedef struct VkBindBufferMemoryInfoKHX { + VkStructureType sType; + const void* pNext; + VkBuffer buffer; + VkDeviceMemory memory; + VkDeviceSize memoryOffset; + uint32_t deviceIndexCount; + const uint32_t* pDeviceIndices; +} VkBindBufferMemoryInfoKHX; + +typedef struct VkBindImageMemoryInfoKHX { + VkStructureType sType; + const void* pNext; + VkImage image; + VkDeviceMemory memory; + VkDeviceSize memoryOffset; + uint32_t deviceIndexCount; + const uint32_t* pDeviceIndices; + uint32_t SFRRectCount; + const VkRect2D* pSFRRects; +} VkBindImageMemoryInfoKHX; + +typedef struct VkDeviceGroupRenderPassBeginInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t deviceMask; + uint32_t deviceRenderAreaCount; + const VkRect2D* pDeviceRenderAreas; +} VkDeviceGroupRenderPassBeginInfoKHX; + +typedef struct VkDeviceGroupCommandBufferBeginInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t deviceMask; +} VkDeviceGroupCommandBufferBeginInfoKHX; + +typedef struct VkDeviceGroupSubmitInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t waitSemaphoreCount; + const uint32_t* pWaitSemaphoreDeviceIndices; + uint32_t commandBufferCount; + const uint32_t* pCommandBufferDeviceMasks; + uint32_t signalSemaphoreCount; + const uint32_t* pSignalSemaphoreDeviceIndices; +} VkDeviceGroupSubmitInfoKHX; + +typedef struct VkDeviceGroupBindSparseInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t resourceDeviceIndex; + uint32_t memoryDeviceIndex; +} VkDeviceGroupBindSparseInfoKHX; + +typedef struct VkDeviceGroupPresentCapabilitiesKHX { + VkStructureType sType; + const void* pNext; + uint32_t presentMask[VK_MAX_DEVICE_GROUP_SIZE_KHX]; + VkDeviceGroupPresentModeFlagsKHX modes; +} VkDeviceGroupPresentCapabilitiesKHX; + +typedef struct VkImageSwapchainCreateInfoKHX { + VkStructureType sType; + const void* pNext; + VkSwapchainKHR swapchain; +} VkImageSwapchainCreateInfoKHX; + +typedef struct VkBindImageMemorySwapchainInfoKHX { + VkStructureType sType; + const void* pNext; + VkSwapchainKHR swapchain; + uint32_t imageIndex; +} VkBindImageMemorySwapchainInfoKHX; + +typedef struct VkAcquireNextImageInfoKHX { + VkStructureType sType; + const void* pNext; + VkSwapchainKHR swapchain; + uint64_t timeout; + VkSemaphore semaphore; + VkFence fence; + uint32_t deviceMask; +} VkAcquireNextImageInfoKHX; + +typedef struct VkDeviceGroupPresentInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t swapchainCount; + const uint32_t* pDeviceMasks; + VkDeviceGroupPresentModeFlagBitsKHX mode; +} VkDeviceGroupPresentInfoKHX; + +typedef struct VkDeviceGroupSwapchainCreateInfoKHX { + VkStructureType sType; + const void* pNext; + VkDeviceGroupPresentModeFlagsKHX modes; +} VkDeviceGroupSwapchainCreateInfoKHX; + + +typedef void (VKAPI_PTR *PFN_vkGetDeviceGroupPeerMemoryFeaturesKHX)(VkDevice device, uint32_t heapIndex, uint32_t localDeviceIndex, uint32_t remoteDeviceIndex, VkPeerMemoryFeatureFlagsKHX* pPeerMemoryFeatures); +typedef VkResult (VKAPI_PTR *PFN_vkBindBufferMemory2KHX)(VkDevice device, uint32_t bindInfoCount, const VkBindBufferMemoryInfoKHX* pBindInfos); +typedef VkResult (VKAPI_PTR *PFN_vkBindImageMemory2KHX)(VkDevice device, uint32_t bindInfoCount, const VkBindImageMemoryInfoKHX* pBindInfos); +typedef void (VKAPI_PTR *PFN_vkCmdSetDeviceMaskKHX)(VkCommandBuffer commandBuffer, uint32_t deviceMask); +typedef VkResult (VKAPI_PTR *PFN_vkGetDeviceGroupPresentCapabilitiesKHX)(VkDevice device, VkDeviceGroupPresentCapabilitiesKHX* pDeviceGroupPresentCapabilities); +typedef VkResult (VKAPI_PTR *PFN_vkGetDeviceGroupSurfacePresentModesKHX)(VkDevice device, VkSurfaceKHR surface, VkDeviceGroupPresentModeFlagsKHX* pModes); +typedef VkResult (VKAPI_PTR *PFN_vkAcquireNextImage2KHX)(VkDevice device, const VkAcquireNextImageInfoKHX* pAcquireInfo, uint32_t* pImageIndex); +typedef void (VKAPI_PTR *PFN_vkCmdDispatchBaseKHX)(VkCommandBuffer commandBuffer, uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ, uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ); +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDevicePresentRectanglesKHX)(VkPhysicalDevice physicalDevice, VkSurfaceKHR surface, uint32_t* pRectCount, VkRect2D* pRects); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetDeviceGroupPeerMemoryFeaturesKHX( + VkDevice device, + uint32_t heapIndex, + uint32_t localDeviceIndex, + uint32_t remoteDeviceIndex, + VkPeerMemoryFeatureFlagsKHX* pPeerMemoryFeatures); + +VKAPI_ATTR VkResult VKAPI_CALL vkBindBufferMemory2KHX( + VkDevice device, + uint32_t bindInfoCount, + const VkBindBufferMemoryInfoKHX* pBindInfos); + +VKAPI_ATTR VkResult VKAPI_CALL vkBindImageMemory2KHX( + VkDevice device, + uint32_t bindInfoCount, + const VkBindImageMemoryInfoKHX* pBindInfos); + +VKAPI_ATTR void VKAPI_CALL vkCmdSetDeviceMaskKHX( + VkCommandBuffer commandBuffer, + uint32_t deviceMask); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetDeviceGroupPresentCapabilitiesKHX( + VkDevice device, + VkDeviceGroupPresentCapabilitiesKHX* pDeviceGroupPresentCapabilities); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetDeviceGroupSurfacePresentModesKHX( + VkDevice device, + VkSurfaceKHR surface, + VkDeviceGroupPresentModeFlagsKHX* pModes); + +VKAPI_ATTR VkResult VKAPI_CALL vkAcquireNextImage2KHX( + VkDevice device, + const VkAcquireNextImageInfoKHX* pAcquireInfo, + uint32_t* pImageIndex); + +VKAPI_ATTR void VKAPI_CALL vkCmdDispatchBaseKHX( + VkCommandBuffer commandBuffer, + uint32_t baseGroupX, + uint32_t baseGroupY, + uint32_t baseGroupZ, + uint32_t groupCountX, + uint32_t groupCountY, + uint32_t groupCountZ); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDevicePresentRectanglesKHX( + VkPhysicalDevice physicalDevice, + VkSurfaceKHR surface, + uint32_t* pRectCount, + VkRect2D* pRects); +#endif + +#define VK_EXT_validation_flags 1 +#define VK_EXT_VALIDATION_FLAGS_SPEC_VERSION 1 +#define VK_EXT_VALIDATION_FLAGS_EXTENSION_NAME "VK_EXT_validation_flags" + + +typedef enum VkValidationCheckEXT { + VK_VALIDATION_CHECK_ALL_EXT = 0, + VK_VALIDATION_CHECK_SHADERS_EXT = 1, + VK_VALIDATION_CHECK_BEGIN_RANGE_EXT = VK_VALIDATION_CHECK_ALL_EXT, + VK_VALIDATION_CHECK_END_RANGE_EXT = VK_VALIDATION_CHECK_SHADERS_EXT, + VK_VALIDATION_CHECK_RANGE_SIZE_EXT = (VK_VALIDATION_CHECK_SHADERS_EXT - VK_VALIDATION_CHECK_ALL_EXT + 1), + VK_VALIDATION_CHECK_MAX_ENUM_EXT = 0x7FFFFFFF +} VkValidationCheckEXT; typedef struct VkValidationFlagsEXT { VkStructureType sType; const void* pNext; - uint32_t disabledValidationCheckCount; - VkValidationCheckEXT* pDisabledValidationChecks; -} VkValidationFlagsEXT; + uint32_t disabledValidationCheckCount; + VkValidationCheckEXT* pDisabledValidationChecks; +} VkValidationFlagsEXT; + + + +#ifdef VK_USE_PLATFORM_VI_NN +#define VK_NN_vi_surface 1 +#define VK_NN_VI_SURFACE_SPEC_VERSION 1 +#define VK_NN_VI_SURFACE_EXTENSION_NAME "VK_NN_vi_surface" + +typedef VkFlags VkViSurfaceCreateFlagsNN; + +typedef struct VkViSurfaceCreateInfoNN { + VkStructureType sType; + const void* pNext; + VkViSurfaceCreateFlagsNN flags; + void* window; +} VkViSurfaceCreateInfoNN; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateViSurfaceNN)(VkInstance instance, const VkViSurfaceCreateInfoNN* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkSurfaceKHR* pSurface); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateViSurfaceNN( + VkInstance instance, + const VkViSurfaceCreateInfoNN* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkSurfaceKHR* pSurface); +#endif +#endif /* VK_USE_PLATFORM_VI_NN */ + +#define VK_EXT_shader_subgroup_ballot 1 +#define VK_EXT_SHADER_SUBGROUP_BALLOT_SPEC_VERSION 1 +#define VK_EXT_SHADER_SUBGROUP_BALLOT_EXTENSION_NAME "VK_EXT_shader_subgroup_ballot" + + +#define VK_EXT_shader_subgroup_vote 1 +#define VK_EXT_SHADER_SUBGROUP_VOTE_SPEC_VERSION 1 +#define VK_EXT_SHADER_SUBGROUP_VOTE_EXTENSION_NAME "VK_EXT_shader_subgroup_vote" + + +#define VK_KHX_device_group_creation 1 +#define VK_KHX_DEVICE_GROUP_CREATION_SPEC_VERSION 1 +#define VK_KHX_DEVICE_GROUP_CREATION_EXTENSION_NAME "VK_KHX_device_group_creation" + +typedef struct VkPhysicalDeviceGroupPropertiesKHX { + VkStructureType sType; + void* pNext; + uint32_t physicalDeviceCount; + VkPhysicalDevice physicalDevices[VK_MAX_DEVICE_GROUP_SIZE_KHX]; + VkBool32 subsetAllocation; +} VkPhysicalDeviceGroupPropertiesKHX; + +typedef struct VkDeviceGroupDeviceCreateInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t physicalDeviceCount; + const VkPhysicalDevice* pPhysicalDevices; +} VkDeviceGroupDeviceCreateInfoKHX; + + +typedef VkResult (VKAPI_PTR *PFN_vkEnumeratePhysicalDeviceGroupsKHX)(VkInstance instance, uint32_t* pPhysicalDeviceGroupCount, VkPhysicalDeviceGroupPropertiesKHX* pPhysicalDeviceGroupProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkEnumeratePhysicalDeviceGroupsKHX( + VkInstance instance, + uint32_t* pPhysicalDeviceGroupCount, + VkPhysicalDeviceGroupPropertiesKHX* pPhysicalDeviceGroupProperties); +#endif + +#define VK_NVX_device_generated_commands 1 +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkObjectTableNVX) +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkIndirectCommandsLayoutNVX) + +#define VK_NVX_DEVICE_GENERATED_COMMANDS_SPEC_VERSION 3 +#define VK_NVX_DEVICE_GENERATED_COMMANDS_EXTENSION_NAME "VK_NVX_device_generated_commands" + + +typedef enum VkIndirectCommandsTokenTypeNVX { + VK_INDIRECT_COMMANDS_TOKEN_TYPE_PIPELINE_NVX = 0, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DESCRIPTOR_SET_NVX = 1, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_INDEX_BUFFER_NVX = 2, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_VERTEX_BUFFER_NVX = 3, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_PUSH_CONSTANT_NVX = 4, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DRAW_INDEXED_NVX = 5, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DRAW_NVX = 6, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DISPATCH_NVX = 7, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_BEGIN_RANGE_NVX = VK_INDIRECT_COMMANDS_TOKEN_TYPE_PIPELINE_NVX, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_END_RANGE_NVX = VK_INDIRECT_COMMANDS_TOKEN_TYPE_DISPATCH_NVX, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_RANGE_SIZE_NVX = (VK_INDIRECT_COMMANDS_TOKEN_TYPE_DISPATCH_NVX - VK_INDIRECT_COMMANDS_TOKEN_TYPE_PIPELINE_NVX + 1), + VK_INDIRECT_COMMANDS_TOKEN_TYPE_MAX_ENUM_NVX = 0x7FFFFFFF +} VkIndirectCommandsTokenTypeNVX; + +typedef enum VkObjectEntryTypeNVX { + VK_OBJECT_ENTRY_TYPE_DESCRIPTOR_SET_NVX = 0, + VK_OBJECT_ENTRY_TYPE_PIPELINE_NVX = 1, + VK_OBJECT_ENTRY_TYPE_INDEX_BUFFER_NVX = 2, + VK_OBJECT_ENTRY_TYPE_VERTEX_BUFFER_NVX = 3, + VK_OBJECT_ENTRY_TYPE_PUSH_CONSTANT_NVX = 4, + VK_OBJECT_ENTRY_TYPE_BEGIN_RANGE_NVX = VK_OBJECT_ENTRY_TYPE_DESCRIPTOR_SET_NVX, + VK_OBJECT_ENTRY_TYPE_END_RANGE_NVX = VK_OBJECT_ENTRY_TYPE_PUSH_CONSTANT_NVX, + VK_OBJECT_ENTRY_TYPE_RANGE_SIZE_NVX = (VK_OBJECT_ENTRY_TYPE_PUSH_CONSTANT_NVX - VK_OBJECT_ENTRY_TYPE_DESCRIPTOR_SET_NVX + 1), + VK_OBJECT_ENTRY_TYPE_MAX_ENUM_NVX = 0x7FFFFFFF +} VkObjectEntryTypeNVX; + + +typedef enum VkIndirectCommandsLayoutUsageFlagBitsNVX { + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_UNORDERED_SEQUENCES_BIT_NVX = 0x00000001, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_SPARSE_SEQUENCES_BIT_NVX = 0x00000002, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_EMPTY_EXECUTIONS_BIT_NVX = 0x00000004, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_INDEXED_SEQUENCES_BIT_NVX = 0x00000008, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_FLAG_BITS_MAX_ENUM_NVX = 0x7FFFFFFF +} VkIndirectCommandsLayoutUsageFlagBitsNVX; +typedef VkFlags VkIndirectCommandsLayoutUsageFlagsNVX; + +typedef enum VkObjectEntryUsageFlagBitsNVX { + VK_OBJECT_ENTRY_USAGE_GRAPHICS_BIT_NVX = 0x00000001, + VK_OBJECT_ENTRY_USAGE_COMPUTE_BIT_NVX = 0x00000002, + VK_OBJECT_ENTRY_USAGE_FLAG_BITS_MAX_ENUM_NVX = 0x7FFFFFFF +} VkObjectEntryUsageFlagBitsNVX; +typedef VkFlags VkObjectEntryUsageFlagsNVX; + +typedef struct VkDeviceGeneratedCommandsFeaturesNVX { + VkStructureType sType; + const void* pNext; + VkBool32 computeBindingPointSupport; +} VkDeviceGeneratedCommandsFeaturesNVX; + +typedef struct VkDeviceGeneratedCommandsLimitsNVX { + VkStructureType sType; + const void* pNext; + uint32_t maxIndirectCommandsLayoutTokenCount; + uint32_t maxObjectEntryCounts; + uint32_t minSequenceCountBufferOffsetAlignment; + uint32_t minSequenceIndexBufferOffsetAlignment; + uint32_t minCommandsTokenBufferOffsetAlignment; +} VkDeviceGeneratedCommandsLimitsNVX; + +typedef struct VkIndirectCommandsTokenNVX { + VkIndirectCommandsTokenTypeNVX tokenType; + VkBuffer buffer; + VkDeviceSize offset; +} VkIndirectCommandsTokenNVX; + +typedef struct VkIndirectCommandsLayoutTokenNVX { + VkIndirectCommandsTokenTypeNVX tokenType; + uint32_t bindingUnit; + uint32_t dynamicCount; + uint32_t divisor; +} VkIndirectCommandsLayoutTokenNVX; + +typedef struct VkIndirectCommandsLayoutCreateInfoNVX { + VkStructureType sType; + const void* pNext; + VkPipelineBindPoint pipelineBindPoint; + VkIndirectCommandsLayoutUsageFlagsNVX flags; + uint32_t tokenCount; + const VkIndirectCommandsLayoutTokenNVX* pTokens; +} VkIndirectCommandsLayoutCreateInfoNVX; + +typedef struct VkCmdProcessCommandsInfoNVX { + VkStructureType sType; + const void* pNext; + VkObjectTableNVX objectTable; + VkIndirectCommandsLayoutNVX indirectCommandsLayout; + uint32_t indirectCommandsTokenCount; + const VkIndirectCommandsTokenNVX* pIndirectCommandsTokens; + uint32_t maxSequencesCount; + VkCommandBuffer targetCommandBuffer; + VkBuffer sequencesCountBuffer; + VkDeviceSize sequencesCountOffset; + VkBuffer sequencesIndexBuffer; + VkDeviceSize sequencesIndexOffset; +} VkCmdProcessCommandsInfoNVX; + +typedef struct VkCmdReserveSpaceForCommandsInfoNVX { + VkStructureType sType; + const void* pNext; + VkObjectTableNVX objectTable; + VkIndirectCommandsLayoutNVX indirectCommandsLayout; + uint32_t maxSequencesCount; +} VkCmdReserveSpaceForCommandsInfoNVX; + +typedef struct VkObjectTableCreateInfoNVX { + VkStructureType sType; + const void* pNext; + uint32_t objectCount; + const VkObjectEntryTypeNVX* pObjectEntryTypes; + const uint32_t* pObjectEntryCounts; + const VkObjectEntryUsageFlagsNVX* pObjectEntryUsageFlags; + uint32_t maxUniformBuffersPerDescriptor; + uint32_t maxStorageBuffersPerDescriptor; + uint32_t maxStorageImagesPerDescriptor; + uint32_t maxSampledImagesPerDescriptor; + uint32_t maxPipelineLayouts; +} VkObjectTableCreateInfoNVX; + +typedef struct VkObjectTableEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; +} VkObjectTableEntryNVX; + +typedef struct VkObjectTablePipelineEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkPipeline pipeline; +} VkObjectTablePipelineEntryNVX; + +typedef struct VkObjectTableDescriptorSetEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkPipelineLayout pipelineLayout; + VkDescriptorSet descriptorSet; +} VkObjectTableDescriptorSetEntryNVX; + +typedef struct VkObjectTableVertexBufferEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkBuffer buffer; +} VkObjectTableVertexBufferEntryNVX; + +typedef struct VkObjectTableIndexBufferEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkBuffer buffer; + VkIndexType indexType; +} VkObjectTableIndexBufferEntryNVX; + +typedef struct VkObjectTablePushConstantEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkPipelineLayout pipelineLayout; + VkShaderStageFlags stageFlags; +} VkObjectTablePushConstantEntryNVX; + + +typedef void (VKAPI_PTR *PFN_vkCmdProcessCommandsNVX)(VkCommandBuffer commandBuffer, const VkCmdProcessCommandsInfoNVX* pProcessCommandsInfo); +typedef void (VKAPI_PTR *PFN_vkCmdReserveSpaceForCommandsNVX)(VkCommandBuffer commandBuffer, const VkCmdReserveSpaceForCommandsInfoNVX* pReserveSpaceInfo); +typedef VkResult (VKAPI_PTR *PFN_vkCreateIndirectCommandsLayoutNVX)(VkDevice device, const VkIndirectCommandsLayoutCreateInfoNVX* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkIndirectCommandsLayoutNVX* pIndirectCommandsLayout); +typedef void (VKAPI_PTR *PFN_vkDestroyIndirectCommandsLayoutNVX)(VkDevice device, VkIndirectCommandsLayoutNVX indirectCommandsLayout, const VkAllocationCallbacks* pAllocator); +typedef VkResult (VKAPI_PTR *PFN_vkCreateObjectTableNVX)(VkDevice device, const VkObjectTableCreateInfoNVX* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkObjectTableNVX* pObjectTable); +typedef void (VKAPI_PTR *PFN_vkDestroyObjectTableNVX)(VkDevice device, VkObjectTableNVX objectTable, const VkAllocationCallbacks* pAllocator); +typedef VkResult (VKAPI_PTR *PFN_vkRegisterObjectsNVX)(VkDevice device, VkObjectTableNVX objectTable, uint32_t objectCount, const VkObjectTableEntryNVX* const* ppObjectTableEntries, const uint32_t* pObjectIndices); +typedef VkResult (VKAPI_PTR *PFN_vkUnregisterObjectsNVX)(VkDevice device, VkObjectTableNVX objectTable, uint32_t objectCount, const VkObjectEntryTypeNVX* pObjectEntryTypes, const uint32_t* pObjectIndices); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceGeneratedCommandsPropertiesNVX)(VkPhysicalDevice physicalDevice, VkDeviceGeneratedCommandsFeaturesNVX* pFeatures, VkDeviceGeneratedCommandsLimitsNVX* pLimits); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdProcessCommandsNVX( + VkCommandBuffer commandBuffer, + const VkCmdProcessCommandsInfoNVX* pProcessCommandsInfo); + +VKAPI_ATTR void VKAPI_CALL vkCmdReserveSpaceForCommandsNVX( + VkCommandBuffer commandBuffer, + const VkCmdReserveSpaceForCommandsInfoNVX* pReserveSpaceInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkCreateIndirectCommandsLayoutNVX( + VkDevice device, + const VkIndirectCommandsLayoutCreateInfoNVX* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkIndirectCommandsLayoutNVX* pIndirectCommandsLayout); + +VKAPI_ATTR void VKAPI_CALL vkDestroyIndirectCommandsLayoutNVX( + VkDevice device, + VkIndirectCommandsLayoutNVX indirectCommandsLayout, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR VkResult VKAPI_CALL vkCreateObjectTableNVX( + VkDevice device, + const VkObjectTableCreateInfoNVX* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkObjectTableNVX* pObjectTable); + +VKAPI_ATTR void VKAPI_CALL vkDestroyObjectTableNVX( + VkDevice device, + VkObjectTableNVX objectTable, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR VkResult VKAPI_CALL vkRegisterObjectsNVX( + VkDevice device, + VkObjectTableNVX objectTable, + uint32_t objectCount, + const VkObjectTableEntryNVX* const* ppObjectTableEntries, + const uint32_t* pObjectIndices); + +VKAPI_ATTR VkResult VKAPI_CALL vkUnregisterObjectsNVX( + VkDevice device, + VkObjectTableNVX objectTable, + uint32_t objectCount, + const VkObjectEntryTypeNVX* pObjectEntryTypes, + const uint32_t* pObjectIndices); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceGeneratedCommandsPropertiesNVX( + VkPhysicalDevice physicalDevice, + VkDeviceGeneratedCommandsFeaturesNVX* pFeatures, + VkDeviceGeneratedCommandsLimitsNVX* pLimits); +#endif + +#define VK_NV_clip_space_w_scaling 1 +#define VK_NV_CLIP_SPACE_W_SCALING_SPEC_VERSION 1 +#define VK_NV_CLIP_SPACE_W_SCALING_EXTENSION_NAME "VK_NV_clip_space_w_scaling" + +typedef struct VkViewportWScalingNV { + float xcoeff; + float ycoeff; +} VkViewportWScalingNV; + +typedef struct VkPipelineViewportWScalingStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkBool32 viewportWScalingEnable; + uint32_t viewportCount; + const VkViewportWScalingNV* pViewportWScalings; +} VkPipelineViewportWScalingStateCreateInfoNV; + + +typedef void (VKAPI_PTR *PFN_vkCmdSetViewportWScalingNV)(VkCommandBuffer commandBuffer, uint32_t firstViewport, uint32_t viewportCount, const VkViewportWScalingNV* pViewportWScalings); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdSetViewportWScalingNV( + VkCommandBuffer commandBuffer, + uint32_t firstViewport, + uint32_t viewportCount, + const VkViewportWScalingNV* pViewportWScalings); +#endif + +#define VK_EXT_direct_mode_display 1 +#define VK_EXT_DIRECT_MODE_DISPLAY_SPEC_VERSION 1 +#define VK_EXT_DIRECT_MODE_DISPLAY_EXTENSION_NAME "VK_EXT_direct_mode_display" + +typedef VkResult (VKAPI_PTR *PFN_vkReleaseDisplayEXT)(VkPhysicalDevice physicalDevice, VkDisplayKHR display); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkReleaseDisplayEXT( + VkPhysicalDevice physicalDevice, + VkDisplayKHR display); +#endif + +#ifdef VK_USE_PLATFORM_XLIB_XRANDR_EXT +#define VK_EXT_acquire_xlib_display 1 +#include + +#define VK_EXT_ACQUIRE_XLIB_DISPLAY_SPEC_VERSION 1 +#define VK_EXT_ACQUIRE_XLIB_DISPLAY_EXTENSION_NAME "VK_EXT_acquire_xlib_display" + +typedef VkResult (VKAPI_PTR *PFN_vkAcquireXlibDisplayEXT)(VkPhysicalDevice physicalDevice, Display* dpy, VkDisplayKHR display); +typedef VkResult (VKAPI_PTR *PFN_vkGetRandROutputDisplayEXT)(VkPhysicalDevice physicalDevice, Display* dpy, RROutput rrOutput, VkDisplayKHR* pDisplay); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkAcquireXlibDisplayEXT( + VkPhysicalDevice physicalDevice, + Display* dpy, + VkDisplayKHR display); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetRandROutputDisplayEXT( + VkPhysicalDevice physicalDevice, + Display* dpy, + RROutput rrOutput, + VkDisplayKHR* pDisplay); +#endif +#endif /* VK_USE_PLATFORM_XLIB_XRANDR_EXT */ + +#define VK_EXT_display_surface_counter 1 +#define VK_EXT_DISPLAY_SURFACE_COUNTER_SPEC_VERSION 1 +#define VK_EXT_DISPLAY_SURFACE_COUNTER_EXTENSION_NAME "VK_EXT_display_surface_counter" +#define VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES2_EXT VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES_2_EXT + + +typedef enum VkSurfaceCounterFlagBitsEXT { + VK_SURFACE_COUNTER_VBLANK_EXT = 0x00000001, + VK_SURFACE_COUNTER_FLAG_BITS_MAX_ENUM_EXT = 0x7FFFFFFF +} VkSurfaceCounterFlagBitsEXT; +typedef VkFlags VkSurfaceCounterFlagsEXT; + +typedef struct VkSurfaceCapabilities2EXT { + VkStructureType sType; + void* pNext; + uint32_t minImageCount; + uint32_t maxImageCount; + VkExtent2D currentExtent; + VkExtent2D minImageExtent; + VkExtent2D maxImageExtent; + uint32_t maxImageArrayLayers; + VkSurfaceTransformFlagsKHR supportedTransforms; + VkSurfaceTransformFlagBitsKHR currentTransform; + VkCompositeAlphaFlagsKHR supportedCompositeAlpha; + VkImageUsageFlags supportedUsageFlags; + VkSurfaceCounterFlagsEXT supportedSurfaceCounters; +} VkSurfaceCapabilities2EXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceSurfaceCapabilities2EXT)(VkPhysicalDevice physicalDevice, VkSurfaceKHR surface, VkSurfaceCapabilities2EXT* pSurfaceCapabilities); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceSurfaceCapabilities2EXT( + VkPhysicalDevice physicalDevice, + VkSurfaceKHR surface, + VkSurfaceCapabilities2EXT* pSurfaceCapabilities); +#endif + +#define VK_EXT_display_control 1 +#define VK_EXT_DISPLAY_CONTROL_SPEC_VERSION 1 +#define VK_EXT_DISPLAY_CONTROL_EXTENSION_NAME "VK_EXT_display_control" + + +typedef enum VkDisplayPowerStateEXT { + VK_DISPLAY_POWER_STATE_OFF_EXT = 0, + VK_DISPLAY_POWER_STATE_SUSPEND_EXT = 1, + VK_DISPLAY_POWER_STATE_ON_EXT = 2, + VK_DISPLAY_POWER_STATE_BEGIN_RANGE_EXT = VK_DISPLAY_POWER_STATE_OFF_EXT, + VK_DISPLAY_POWER_STATE_END_RANGE_EXT = VK_DISPLAY_POWER_STATE_ON_EXT, + VK_DISPLAY_POWER_STATE_RANGE_SIZE_EXT = (VK_DISPLAY_POWER_STATE_ON_EXT - VK_DISPLAY_POWER_STATE_OFF_EXT + 1), + VK_DISPLAY_POWER_STATE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDisplayPowerStateEXT; + +typedef enum VkDeviceEventTypeEXT { + VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT = 0, + VK_DEVICE_EVENT_TYPE_BEGIN_RANGE_EXT = VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT, + VK_DEVICE_EVENT_TYPE_END_RANGE_EXT = VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT, + VK_DEVICE_EVENT_TYPE_RANGE_SIZE_EXT = (VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT - VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT + 1), + VK_DEVICE_EVENT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDeviceEventTypeEXT; + +typedef enum VkDisplayEventTypeEXT { + VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT = 0, + VK_DISPLAY_EVENT_TYPE_BEGIN_RANGE_EXT = VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT, + VK_DISPLAY_EVENT_TYPE_END_RANGE_EXT = VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT, + VK_DISPLAY_EVENT_TYPE_RANGE_SIZE_EXT = (VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT - VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT + 1), + VK_DISPLAY_EVENT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDisplayEventTypeEXT; + +typedef struct VkDisplayPowerInfoEXT { + VkStructureType sType; + const void* pNext; + VkDisplayPowerStateEXT powerState; +} VkDisplayPowerInfoEXT; + +typedef struct VkDeviceEventInfoEXT { + VkStructureType sType; + const void* pNext; + VkDeviceEventTypeEXT deviceEvent; +} VkDeviceEventInfoEXT; + +typedef struct VkDisplayEventInfoEXT { + VkStructureType sType; + const void* pNext; + VkDisplayEventTypeEXT displayEvent; +} VkDisplayEventInfoEXT; + +typedef struct VkSwapchainCounterCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkSurfaceCounterFlagsEXT surfaceCounters; +} VkSwapchainCounterCreateInfoEXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkDisplayPowerControlEXT)(VkDevice device, VkDisplayKHR display, const VkDisplayPowerInfoEXT* pDisplayPowerInfo); +typedef VkResult (VKAPI_PTR *PFN_vkRegisterDeviceEventEXT)(VkDevice device, const VkDeviceEventInfoEXT* pDeviceEventInfo, const VkAllocationCallbacks* pAllocator, VkFence* pFence); +typedef VkResult (VKAPI_PTR *PFN_vkRegisterDisplayEventEXT)(VkDevice device, VkDisplayKHR display, const VkDisplayEventInfoEXT* pDisplayEventInfo, const VkAllocationCallbacks* pAllocator, VkFence* pFence); +typedef VkResult (VKAPI_PTR *PFN_vkGetSwapchainCounterEXT)(VkDevice device, VkSwapchainKHR swapchain, VkSurfaceCounterFlagBitsEXT counter, uint64_t* pCounterValue); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkDisplayPowerControlEXT( + VkDevice device, + VkDisplayKHR display, + const VkDisplayPowerInfoEXT* pDisplayPowerInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkRegisterDeviceEventEXT( + VkDevice device, + const VkDeviceEventInfoEXT* pDeviceEventInfo, + const VkAllocationCallbacks* pAllocator, + VkFence* pFence); + +VKAPI_ATTR VkResult VKAPI_CALL vkRegisterDisplayEventEXT( + VkDevice device, + VkDisplayKHR display, + const VkDisplayEventInfoEXT* pDisplayEventInfo, + const VkAllocationCallbacks* pAllocator, + VkFence* pFence); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetSwapchainCounterEXT( + VkDevice device, + VkSwapchainKHR swapchain, + VkSurfaceCounterFlagBitsEXT counter, + uint64_t* pCounterValue); +#endif + +#define VK_GOOGLE_display_timing 1 +#define VK_GOOGLE_DISPLAY_TIMING_SPEC_VERSION 1 +#define VK_GOOGLE_DISPLAY_TIMING_EXTENSION_NAME "VK_GOOGLE_display_timing" + +typedef struct VkRefreshCycleDurationGOOGLE { + uint64_t refreshDuration; +} VkRefreshCycleDurationGOOGLE; + +typedef struct VkPastPresentationTimingGOOGLE { + uint32_t presentID; + uint64_t desiredPresentTime; + uint64_t actualPresentTime; + uint64_t earliestPresentTime; + uint64_t presentMargin; +} VkPastPresentationTimingGOOGLE; + +typedef struct VkPresentTimeGOOGLE { + uint32_t presentID; + uint64_t desiredPresentTime; +} VkPresentTimeGOOGLE; + +typedef struct VkPresentTimesInfoGOOGLE { + VkStructureType sType; + const void* pNext; + uint32_t swapchainCount; + const VkPresentTimeGOOGLE* pTimes; +} VkPresentTimesInfoGOOGLE; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetRefreshCycleDurationGOOGLE)(VkDevice device, VkSwapchainKHR swapchain, VkRefreshCycleDurationGOOGLE* pDisplayTimingProperties); +typedef VkResult (VKAPI_PTR *PFN_vkGetPastPresentationTimingGOOGLE)(VkDevice device, VkSwapchainKHR swapchain, uint32_t* pPresentationTimingCount, VkPastPresentationTimingGOOGLE* pPresentationTimings); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetRefreshCycleDurationGOOGLE( + VkDevice device, + VkSwapchainKHR swapchain, + VkRefreshCycleDurationGOOGLE* pDisplayTimingProperties); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPastPresentationTimingGOOGLE( + VkDevice device, + VkSwapchainKHR swapchain, + uint32_t* pPresentationTimingCount, + VkPastPresentationTimingGOOGLE* pPresentationTimings); +#endif + +#define VK_NV_sample_mask_override_coverage 1 +#define VK_NV_SAMPLE_MASK_OVERRIDE_COVERAGE_SPEC_VERSION 1 +#define VK_NV_SAMPLE_MASK_OVERRIDE_COVERAGE_EXTENSION_NAME "VK_NV_sample_mask_override_coverage" + + +#define VK_NV_geometry_shader_passthrough 1 +#define VK_NV_GEOMETRY_SHADER_PASSTHROUGH_SPEC_VERSION 1 +#define VK_NV_GEOMETRY_SHADER_PASSTHROUGH_EXTENSION_NAME "VK_NV_geometry_shader_passthrough" + + +#define VK_NV_viewport_array2 1 +#define VK_NV_VIEWPORT_ARRAY2_SPEC_VERSION 1 +#define VK_NV_VIEWPORT_ARRAY2_EXTENSION_NAME "VK_NV_viewport_array2" + + +#define VK_NVX_multiview_per_view_attributes 1 +#define VK_NVX_MULTIVIEW_PER_VIEW_ATTRIBUTES_SPEC_VERSION 1 +#define VK_NVX_MULTIVIEW_PER_VIEW_ATTRIBUTES_EXTENSION_NAME "VK_NVX_multiview_per_view_attributes" + +typedef struct VkPhysicalDeviceMultiviewPerViewAttributesPropertiesNVX { + VkStructureType sType; + void* pNext; + VkBool32 perViewPositionAllComponents; +} VkPhysicalDeviceMultiviewPerViewAttributesPropertiesNVX; + + + +#define VK_NV_viewport_swizzle 1 +#define VK_NV_VIEWPORT_SWIZZLE_SPEC_VERSION 1 +#define VK_NV_VIEWPORT_SWIZZLE_EXTENSION_NAME "VK_NV_viewport_swizzle" + + +typedef enum VkViewportCoordinateSwizzleNV { + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_X_NV = 0, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_X_NV = 1, + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_Y_NV = 2, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_Y_NV = 3, + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_Z_NV = 4, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_Z_NV = 5, + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_W_NV = 6, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_W_NV = 7, + VK_VIEWPORT_COORDINATE_SWIZZLE_BEGIN_RANGE_NV = VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_X_NV, + VK_VIEWPORT_COORDINATE_SWIZZLE_END_RANGE_NV = VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_W_NV, + VK_VIEWPORT_COORDINATE_SWIZZLE_RANGE_SIZE_NV = (VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_W_NV - VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_X_NV + 1), + VK_VIEWPORT_COORDINATE_SWIZZLE_MAX_ENUM_NV = 0x7FFFFFFF +} VkViewportCoordinateSwizzleNV; + +typedef VkFlags VkPipelineViewportSwizzleStateCreateFlagsNV; + +typedef struct VkViewportSwizzleNV { + VkViewportCoordinateSwizzleNV x; + VkViewportCoordinateSwizzleNV y; + VkViewportCoordinateSwizzleNV z; + VkViewportCoordinateSwizzleNV w; +} VkViewportSwizzleNV; + +typedef struct VkPipelineViewportSwizzleStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkPipelineViewportSwizzleStateCreateFlagsNV flags; + uint32_t viewportCount; + const VkViewportSwizzleNV* pViewportSwizzles; +} VkPipelineViewportSwizzleStateCreateInfoNV; + + + +#define VK_EXT_discard_rectangles 1 +#define VK_EXT_DISCARD_RECTANGLES_SPEC_VERSION 1 +#define VK_EXT_DISCARD_RECTANGLES_EXTENSION_NAME "VK_EXT_discard_rectangles" + + +typedef enum VkDiscardRectangleModeEXT { + VK_DISCARD_RECTANGLE_MODE_INCLUSIVE_EXT = 0, + VK_DISCARD_RECTANGLE_MODE_EXCLUSIVE_EXT = 1, + VK_DISCARD_RECTANGLE_MODE_BEGIN_RANGE_EXT = VK_DISCARD_RECTANGLE_MODE_INCLUSIVE_EXT, + VK_DISCARD_RECTANGLE_MODE_END_RANGE_EXT = VK_DISCARD_RECTANGLE_MODE_EXCLUSIVE_EXT, + VK_DISCARD_RECTANGLE_MODE_RANGE_SIZE_EXT = (VK_DISCARD_RECTANGLE_MODE_EXCLUSIVE_EXT - VK_DISCARD_RECTANGLE_MODE_INCLUSIVE_EXT + 1), + VK_DISCARD_RECTANGLE_MODE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDiscardRectangleModeEXT; + +typedef VkFlags VkPipelineDiscardRectangleStateCreateFlagsEXT; + +typedef struct VkPhysicalDeviceDiscardRectanglePropertiesEXT { + VkStructureType sType; + void* pNext; + uint32_t maxDiscardRectangles; +} VkPhysicalDeviceDiscardRectanglePropertiesEXT; + +typedef struct VkPipelineDiscardRectangleStateCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkPipelineDiscardRectangleStateCreateFlagsEXT flags; + VkDiscardRectangleModeEXT discardRectangleMode; + uint32_t discardRectangleCount; + const VkRect2D* pDiscardRectangles; +} VkPipelineDiscardRectangleStateCreateInfoEXT; + + +typedef void (VKAPI_PTR *PFN_vkCmdSetDiscardRectangleEXT)(VkCommandBuffer commandBuffer, uint32_t firstDiscardRectangle, uint32_t discardRectangleCount, const VkRect2D* pDiscardRectangles); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdSetDiscardRectangleEXT( + VkCommandBuffer commandBuffer, + uint32_t firstDiscardRectangle, + uint32_t discardRectangleCount, + const VkRect2D* pDiscardRectangles); +#endif + +#define VK_EXT_swapchain_colorspace 1 +#define VK_EXT_SWAPCHAIN_COLOR_SPACE_SPEC_VERSION 3 +#define VK_EXT_SWAPCHAIN_COLOR_SPACE_EXTENSION_NAME "VK_EXT_swapchain_colorspace" + + +#define VK_EXT_hdr_metadata 1 +#define VK_EXT_HDR_METADATA_SPEC_VERSION 1 +#define VK_EXT_HDR_METADATA_EXTENSION_NAME "VK_EXT_hdr_metadata" + +typedef struct VkXYColorEXT { + float x; + float y; +} VkXYColorEXT; + +typedef struct VkHdrMetadataEXT { + VkStructureType sType; + const void* pNext; + VkXYColorEXT displayPrimaryRed; + VkXYColorEXT displayPrimaryGreen; + VkXYColorEXT displayPrimaryBlue; + VkXYColorEXT whitePoint; + float maxLuminance; + float minLuminance; + float maxContentLightLevel; + float maxFrameAverageLightLevel; +} VkHdrMetadataEXT; + + +typedef void (VKAPI_PTR *PFN_vkSetHdrMetadataEXT)(VkDevice device, uint32_t swapchainCount, const VkSwapchainKHR* pSwapchains, const VkHdrMetadataEXT* pMetadata); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkSetHdrMetadataEXT( + VkDevice device, + uint32_t swapchainCount, + const VkSwapchainKHR* pSwapchains, + const VkHdrMetadataEXT* pMetadata); +#endif + +#ifdef VK_USE_PLATFORM_IOS_MVK +#define VK_MVK_ios_surface 1 +#define VK_MVK_IOS_SURFACE_SPEC_VERSION 2 +#define VK_MVK_IOS_SURFACE_EXTENSION_NAME "VK_MVK_ios_surface" + +typedef VkFlags VkIOSSurfaceCreateFlagsMVK; + +typedef struct VkIOSSurfaceCreateInfoMVK { + VkStructureType sType; + const void* pNext; + VkIOSSurfaceCreateFlagsMVK flags; + const void* pView; +} VkIOSSurfaceCreateInfoMVK; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateIOSSurfaceMVK)(VkInstance instance, const VkIOSSurfaceCreateInfoMVK* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkSurfaceKHR* pSurface); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateIOSSurfaceMVK( + VkInstance instance, + const VkIOSSurfaceCreateInfoMVK* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkSurfaceKHR* pSurface); +#endif +#endif /* VK_USE_PLATFORM_IOS_MVK */ + +#ifdef VK_USE_PLATFORM_MACOS_MVK +#define VK_MVK_macos_surface 1 +#define VK_MVK_MACOS_SURFACE_SPEC_VERSION 2 +#define VK_MVK_MACOS_SURFACE_EXTENSION_NAME "VK_MVK_macos_surface" + +typedef VkFlags VkMacOSSurfaceCreateFlagsMVK; + +typedef struct VkMacOSSurfaceCreateInfoMVK { + VkStructureType sType; + const void* pNext; + VkMacOSSurfaceCreateFlagsMVK flags; + const void* pView; +} VkMacOSSurfaceCreateInfoMVK; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateMacOSSurfaceMVK)(VkInstance instance, const VkMacOSSurfaceCreateInfoMVK* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkSurfaceKHR* pSurface); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateMacOSSurfaceMVK( + VkInstance instance, + const VkMacOSSurfaceCreateInfoMVK* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkSurfaceKHR* pSurface); +#endif +#endif /* VK_USE_PLATFORM_MACOS_MVK */ + +#define VK_EXT_sampler_filter_minmax 1 +#define VK_EXT_SAMPLER_FILTER_MINMAX_SPEC_VERSION 1 +#define VK_EXT_SAMPLER_FILTER_MINMAX_EXTENSION_NAME "VK_EXT_sampler_filter_minmax" + + +typedef enum VkSamplerReductionModeEXT { + VK_SAMPLER_REDUCTION_MODE_WEIGHTED_AVERAGE_EXT = 0, + VK_SAMPLER_REDUCTION_MODE_MIN_EXT = 1, + VK_SAMPLER_REDUCTION_MODE_MAX_EXT = 2, + VK_SAMPLER_REDUCTION_MODE_BEGIN_RANGE_EXT = VK_SAMPLER_REDUCTION_MODE_WEIGHTED_AVERAGE_EXT, + VK_SAMPLER_REDUCTION_MODE_END_RANGE_EXT = VK_SAMPLER_REDUCTION_MODE_MAX_EXT, + VK_SAMPLER_REDUCTION_MODE_RANGE_SIZE_EXT = (VK_SAMPLER_REDUCTION_MODE_MAX_EXT - VK_SAMPLER_REDUCTION_MODE_WEIGHTED_AVERAGE_EXT + 1), + VK_SAMPLER_REDUCTION_MODE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkSamplerReductionModeEXT; + +typedef struct VkSamplerReductionModeCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkSamplerReductionModeEXT reductionMode; +} VkSamplerReductionModeCreateInfoEXT; + +typedef struct VkPhysicalDeviceSamplerFilterMinmaxPropertiesEXT { + VkStructureType sType; + void* pNext; + VkBool32 filterMinmaxSingleComponentFormats; + VkBool32 filterMinmaxImageComponentMapping; +} VkPhysicalDeviceSamplerFilterMinmaxPropertiesEXT; + + + +#define VK_AMD_gpu_shader_int16 1 +#define VK_AMD_GPU_SHADER_INT16_SPEC_VERSION 1 +#define VK_AMD_GPU_SHADER_INT16_EXTENSION_NAME "VK_AMD_gpu_shader_int16" + + +#define VK_AMD_mixed_attachment_samples 1 +#define VK_AMD_MIXED_ATTACHMENT_SAMPLES_SPEC_VERSION 1 +#define VK_AMD_MIXED_ATTACHMENT_SAMPLES_EXTENSION_NAME "VK_AMD_mixed_attachment_samples" + + +#define VK_EXT_shader_stencil_export 1 +#define VK_EXT_SHADER_STENCIL_EXPORT_SPEC_VERSION 1 +#define VK_EXT_SHADER_STENCIL_EXPORT_EXTENSION_NAME "VK_EXT_shader_stencil_export" + + +#define VK_EXT_blend_operation_advanced 1 +#define VK_EXT_BLEND_OPERATION_ADVANCED_SPEC_VERSION 2 +#define VK_EXT_BLEND_OPERATION_ADVANCED_EXTENSION_NAME "VK_EXT_blend_operation_advanced" + + +typedef enum VkBlendOverlapEXT { + VK_BLEND_OVERLAP_UNCORRELATED_EXT = 0, + VK_BLEND_OVERLAP_DISJOINT_EXT = 1, + VK_BLEND_OVERLAP_CONJOINT_EXT = 2, + VK_BLEND_OVERLAP_BEGIN_RANGE_EXT = VK_BLEND_OVERLAP_UNCORRELATED_EXT, + VK_BLEND_OVERLAP_END_RANGE_EXT = VK_BLEND_OVERLAP_CONJOINT_EXT, + VK_BLEND_OVERLAP_RANGE_SIZE_EXT = (VK_BLEND_OVERLAP_CONJOINT_EXT - VK_BLEND_OVERLAP_UNCORRELATED_EXT + 1), + VK_BLEND_OVERLAP_MAX_ENUM_EXT = 0x7FFFFFFF +} VkBlendOverlapEXT; + +typedef struct VkPhysicalDeviceBlendOperationAdvancedFeaturesEXT { + VkStructureType sType; + void* pNext; + VkBool32 advancedBlendCoherentOperations; +} VkPhysicalDeviceBlendOperationAdvancedFeaturesEXT; + +typedef struct VkPhysicalDeviceBlendOperationAdvancedPropertiesEXT { + VkStructureType sType; + void* pNext; + uint32_t advancedBlendMaxColorAttachments; + VkBool32 advancedBlendIndependentBlend; + VkBool32 advancedBlendNonPremultipliedSrcColor; + VkBool32 advancedBlendNonPremultipliedDstColor; + VkBool32 advancedBlendCorrelatedOverlap; + VkBool32 advancedBlendAllOperations; +} VkPhysicalDeviceBlendOperationAdvancedPropertiesEXT; + +typedef struct VkPipelineColorBlendAdvancedStateCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkBool32 srcPremultiplied; + VkBool32 dstPremultiplied; + VkBlendOverlapEXT blendOverlap; +} VkPipelineColorBlendAdvancedStateCreateInfoEXT; + + + +#define VK_NV_fragment_coverage_to_color 1 +#define VK_NV_FRAGMENT_COVERAGE_TO_COLOR_SPEC_VERSION 1 +#define VK_NV_FRAGMENT_COVERAGE_TO_COLOR_EXTENSION_NAME "VK_NV_fragment_coverage_to_color" + +typedef VkFlags VkPipelineCoverageToColorStateCreateFlagsNV; + +typedef struct VkPipelineCoverageToColorStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkPipelineCoverageToColorStateCreateFlagsNV flags; + VkBool32 coverageToColorEnable; + uint32_t coverageToColorLocation; +} VkPipelineCoverageToColorStateCreateInfoNV; + + + +#define VK_NV_framebuffer_mixed_samples 1 +#define VK_NV_FRAMEBUFFER_MIXED_SAMPLES_SPEC_VERSION 1 +#define VK_NV_FRAMEBUFFER_MIXED_SAMPLES_EXTENSION_NAME "VK_NV_framebuffer_mixed_samples" + + +typedef enum VkCoverageModulationModeNV { + VK_COVERAGE_MODULATION_MODE_NONE_NV = 0, + VK_COVERAGE_MODULATION_MODE_RGB_NV = 1, + VK_COVERAGE_MODULATION_MODE_ALPHA_NV = 2, + VK_COVERAGE_MODULATION_MODE_RGBA_NV = 3, + VK_COVERAGE_MODULATION_MODE_BEGIN_RANGE_NV = VK_COVERAGE_MODULATION_MODE_NONE_NV, + VK_COVERAGE_MODULATION_MODE_END_RANGE_NV = VK_COVERAGE_MODULATION_MODE_RGBA_NV, + VK_COVERAGE_MODULATION_MODE_RANGE_SIZE_NV = (VK_COVERAGE_MODULATION_MODE_RGBA_NV - VK_COVERAGE_MODULATION_MODE_NONE_NV + 1), + VK_COVERAGE_MODULATION_MODE_MAX_ENUM_NV = 0x7FFFFFFF +} VkCoverageModulationModeNV; + +typedef VkFlags VkPipelineCoverageModulationStateCreateFlagsNV; + +typedef struct VkPipelineCoverageModulationStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkPipelineCoverageModulationStateCreateFlagsNV flags; + VkCoverageModulationModeNV coverageModulationMode; + VkBool32 coverageModulationTableEnable; + uint32_t coverageModulationTableCount; + const float* pCoverageModulationTable; +} VkPipelineCoverageModulationStateCreateInfoNV; + + + +#define VK_NV_fill_rectangle 1 +#define VK_NV_FILL_RECTANGLE_SPEC_VERSION 1 +#define VK_NV_FILL_RECTANGLE_EXTENSION_NAME "VK_NV_fill_rectangle" + + +#define VK_EXT_post_depth_coverage 1 +#define VK_EXT_POST_DEPTH_COVERAGE_SPEC_VERSION 1 +#define VK_EXT_POST_DEPTH_COVERAGE_EXTENSION_NAME "VK_EXT_post_depth_coverage" + +#define VK_EXT_shader_viewport_index_layer 1 +#define VK_EXT_SHADER_VIEWPORT_INDEX_LAYER_SPEC_VERSION 1 +#define VK_EXT_SHADER_VIEWPORT_INDEX_LAYER_EXTENSION_NAME "VK_EXT_shader_viewport_index_layer" #ifdef __cplusplus diff --git a/caffe2/operators/elementwise_ops_utils.cc b/caffe2/operators/elementwise_ops_utils.cc index 5bb6c768ea3e..0f76a1b35aa4 100644 --- a/caffe2/operators/elementwise_ops_utils.cc +++ b/caffe2/operators/elementwise_ops_utils.cc @@ -53,7 +53,10 @@ std::vector ComputeBinaryBroadcastForwardDims( for (; i >= 0 && j >= 0; --k) { const int A_dim = A_dims[i]; const int B_dim = B_dims[j]; - CAFFE_ENFORCE(A_dim == B_dim || A_dim == 1 || B_dim == 1); + CAFFE_ENFORCE( + A_dim == B_dim || A_dim == 1 || B_dim == 1, + "A_dim: ", A_dim , ",B_dim: ", B_dim + ); if (A_dim == 0 || B_dim == 0) { C_dims[k] = 0; } else { diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index 95302ca5ccc4..f9c9b6acf034 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -270,6 +270,30 @@ TEST(BoundShapeInference, LengthsRangeFill) { TensorProto_DataType_INT32); } + +TEST(BoundShapeInference, ConstantFill) { + NetDef net; + net.add_op()->CopyFrom( + CreateOperatorDef("ConstantFill", "", {"X"}, {"Y"}, {})); + ShapeInfoMap shape_map; + BoundShapeSpec spec(20, 1000); + BoundShapeInferencer eng(spec); + shape_map.emplace( + "X", + makeTensorInfo( + {TensorBoundShape_DimType_BATCH, + TensorBoundShape_DimType_CONSTANT}, + {20, 1024})); + eng.InferBoundShapeAndType(net, shape_map, nullptr); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, + "Y", + {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT}, + {20, 1024}, + TensorProto_DataType_FLOAT); +} + // https://github.com/pytorch/pytorch/issues/40861 TEST(BoundShapeInference, DISABLED_ON_WINDOWS(Reshape)) { NetDef net; diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index c513c1a37b01..8ef5de06b02e 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -322,6 +322,12 @@ void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) { if (it != shape_info_.end()) { it->second.setDimType(std::vector( it->second.shape.dims_size(), TensorBoundShape_DimType_CONSTANT)); + if (op.type() == "ConstantFill" && op.input_size() >= 1) { + auto it_input = shape_info_.find(op.input(0)); + if (it_input != shape_info_.end()) { + it->second.setDimType(it_input->second.getDimType()); + } + } } } diff --git a/caffe2/opt/onnxifi_op.h b/caffe2/opt/onnxifi_op.h index eeb93c51e6f8..ce732f7604bc 100644 --- a/caffe2/opt/onnxifi_op.h +++ b/caffe2/opt/onnxifi_op.h @@ -128,6 +128,10 @@ class OnnxifiOp final : public Operator { adjust_quantized_offset_ = 0; } + LOG(INFO) << "use_onnx_=" << use_onnx_ + << ", use_glow_aot_=" << use_glow_aot_ + << ", use_passed_output_shapes_=" << use_passed_output_shapes_; + if (use_passed_output_shapes_) { // Populate output_shapes_per_bs_ for (int bs = 1; bs < max_batch_size_; ++bs) { @@ -145,6 +149,7 @@ class OnnxifiOp final : public Operator { for (output_idx = 0; output_idx < output_names_.size(); ++output_idx) { auto it = name_to_shape.find(output_names_[output_idx]); + CAFFE_ENFORCE(it != name_to_shape.end()); output_shapes_per_bs_[bs].push_back({}); auto &output_shapes = output_shapes_per_bs_[bs].back(); std::copy(it->second.dims.cbegin(), it->second.dims.cend(), std::back_inserter(output_shapes)); @@ -486,7 +491,7 @@ class OnnxifiOp final : public Operator { std::unordered_map input_shape_info_; // Whether we should use passed output shape hints or do shape inference - bool use_passed_output_shapes_{false}; + const bool use_passed_output_shapes_{false}; // Whether we need to resize outputs or not bool adjust_output_batch_{false}; diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 9ccc662d99a9..8089314c3100 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -506,6 +506,31 @@ OnnxifiTransformer::~OnnxifiTransformer() { } } +bool OnnxifiTransformer::canPassOutputShapeHintsPerBs( + const OperatorDef& op, + const std::unordered_map& shape_hints_per_bs) const { + if (shape_hints_per_bs.empty()) { + return false; + } + + for (int bs = 1; bs < opts_.bound_shape_spec.max_batch_size; ++bs) { + auto shape_hints_search = shape_hints_per_bs.find(bs); + if (shape_hints_search == shape_hints_per_bs.end()) { + return false; + } + const auto& shape_hints = shape_hints_search->second; + + for (int output_idx = 0; output_idx < op.output_size(); ++output_idx) { + auto shape_hint_search = shape_hints.find(op.output(output_idx)); + if (shape_hint_search == shape_hints.end()) { + return false; + } + } + } + + return true; +} + OperatorDef OnnxifiTransformer::buildOnnxifiOp( const std::string& onnx_model_str, const std::unordered_set& initialization_list, @@ -583,31 +608,31 @@ OperatorDef OnnxifiTransformer::buildOnnxifiOp( } } - // Add output size hints for per batch size - AddArgument("use_passed_output_shapes", shape_hints_per_bs.empty() ? 0 : 1, &op); - if (!shape_hints_per_bs.empty()) { - for (int bs = 1; bs < opts_.bound_shape_spec.max_batch_size; ++bs) { - auto it = shape_hints_per_bs.find(bs); - CAFFE_ENFORCE(it != shape_hints_per_bs.end()); - const auto& shape_hints_current_bs = it->second; + // Add output size hints per batch size + if (canPassOutputShapeHintsPerBs(op, shape_hints_per_bs)) { + VLOG(2) << "Passing in output shape hints for batch sizes in [1, " << opts_.bound_shape_spec.max_batch_size << ")"; + AddArgument("use_passed_output_shapes", 1, &op); + for (int bs = 1; bs < opts_.bound_shape_spec.max_batch_size; ++bs) { auto* output_shape_arg = op.add_arg(); output_shape_arg->set_name("output_shapes_bs_" + caffe2::to_string(bs)); auto* output_qshape_arg = op.add_arg(); output_qshape_arg->set_name("output_qshapes_bs_" + caffe2::to_string(bs)); + const auto& shape_hints = shape_hints_per_bs.find(bs)->second; + for (int output_idx = 0; output_idx < op.output_size(); ++output_idx) { const auto& output_name = op.output(output_idx); - auto it_output = shape_hints_current_bs.find(output_name); - if (it_output != shape_hints_current_bs.end()) { - if (!it_output->second.is_quantized) { - output_shape_arg->mutable_tensors()->Add()->CopyFrom(wrapShapeInfoIntoTensorProto(output_name, it_output->second)); - } else { - output_shape_arg->mutable_qtensors()->Add()->CopyFrom(wrapShapeInfoIntoQTensorProto(output_name, it_output->second)); - } + const auto& shape_hint = shape_hints.find(output_name)->second; + if (!shape_hint.is_quantized) { + output_shape_arg->mutable_tensors()->Add()->CopyFrom(wrapShapeInfoIntoTensorProto(output_name, shape_hint)); + } else { + output_shape_arg->mutable_qtensors()->Add()->CopyFrom(wrapShapeInfoIntoQTensorProto(output_name, shape_hint)); } } } + } else { + AddArgument("use_passed_output_shapes", 0, &op); } // Tell Onnxifi op that the model is in onnx or c2 proto format diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index 5836486bfd31..d86f112dd485 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -82,6 +82,12 @@ class CAFFE2_API OnnxifiTransformer final : public BackendTransformerBase { const ShapeInfoMap& shape_hints_max_bs, const std::unordered_map &shape_hints_per_bs); + // Check that output shape hints are present to ensure we can pass them to + // OnnxifiOp + bool canPassOutputShapeHintsPerBs( + const OperatorDef& op, + const std::unordered_map& shape_hints_per_bs) const; + // We already have all the ops and external inputs and outputs! OperatorDef buildOnnxifiOp( const std::string& onnx_model_str, diff --git a/caffe2/quantization/server/int8_gen_quant_params_min_max.cc b/caffe2/quantization/server/int8_gen_quant_params_min_max.cc new file mode 100644 index 000000000000..76a2bb747242 --- /dev/null +++ b/caffe2/quantization/server/int8_gen_quant_params_min_max.cc @@ -0,0 +1,37 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include "caffe2/quantization/server/int8_gen_quant_params_min_max.h" +#include +#include "caffe2/quantization/server/int8_gen_quant_params.h" + +namespace caffe2 { +using namespace std; +using namespace dnnlowp; + +REGISTER_CPU_OPERATOR( + Int8GenQuantParamsMinMax, + Int8GenQuantParamsMinMaxOp); +OPERATOR_SCHEMA(Int8GenQuantParamsMinMax) + .NumInputs(2, 3) + .NumOutputs(1) + .TensorInferenceFunction([](const OperatorDef& /* def */, + const vector& /* in */) { + vector out(1); + out[0].set_data_type(TensorProto_DataType_FLOAT); + out[0].add_dims(1); + return out; + }) + .Input(0, "min", "The lower bound of the tensor to be quantized.") + .Input(1, "max", "The upper bound of the tensor to be quantized.") + .Input( + 2, + "quant_scheme", + "(Optional) Int8QuantSchemeBlob that specifies the quantization kind and preserve_sparsity options when generating the quant params. We only use preserve_sparsity in this op which is default to be false.") + .Output( + 0, + "quant_param", + "Int8QuantParamsBlob that contains the scale and zero_point info in TensorQuantizationParams type.") + .SetDoc( + R"DOC(Operator wrapper for generating int8 tensor quantization parameters given lower and upper bound of the input tensor)DOC"); + +} // namespace caffe2 diff --git a/caffe2/quantization/server/int8_gen_quant_params_min_max.h b/caffe2/quantization/server/int8_gen_quant_params_min_max.h new file mode 100644 index 000000000000..ada6a46a8dec --- /dev/null +++ b/caffe2/quantization/server/int8_gen_quant_params_min_max.h @@ -0,0 +1,50 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once +#include "caffe2/quantization/server/caffe2_dnnlowp_utils.h" +#include "caffe2/quantization/server/dnnlowp.h" +#include "caffe2/quantization/server/int8_gen_quant_params.h" +#include + + +namespace caffe2 { +using namespace std; +using dnnlowp::TensorQuantizationParams; + +template +class Int8GenQuantParamsMinMaxOp final : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + Int8GenQuantParamsMinMaxOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + bool RunOnDevice() override { + // Generate Int8 quant params based on the input data (last N samples of the + // activations) and the quant scheme + const float min = + OperatorBase::Input(0, CPU).template data()[0]; + const float max = + OperatorBase::Input(1, CPU).template data()[0]; + bool preserve_sparsity = false; + if (InputSize() == 3){ + const auto* quant_scheme = + this->template Input>(2).get(); + preserve_sparsity = quant_scheme->preserve_sparsity_; + } + dnnlowp::QuantizationFactory* qfactory = + dnnlowp::QuantizationFactory::GetDefaultInstance(); + TensorQuantizationParams qparam = qfactory->ChooseQuantizationParams( + min, + max, + 8, + preserve_sparsity); + auto* output_qparam = + this->template Output>(0); + output_qparam->reset( + new Int8QuantParamsBlob(qparam.scale, qparam.zero_point)); + LOG_EVERY_N(INFO, 1) << "scale and bias are " << qparam.scale << "," << qparam.zero_point; + return true; + } + +}; // class Int8GenQuantParamsOp + +} // namespace caffe2 diff --git a/caffe2/quantization/server/int8_gen_quant_params_min_max_test.py b/caffe2/quantization/server/int8_gen_quant_params_min_max_test.py new file mode 100644 index 000000000000..dd27074db5c4 --- /dev/null +++ b/caffe2/quantization/server/int8_gen_quant_params_min_max_test.py @@ -0,0 +1,83 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + + + +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np +from caffe2.python import core, workspace +from caffe2.quantization.server import dnnlowp_pybind11 +from hypothesis import given, settings + + +class TestInt8GenQuantParamsMinMaxOperator(hu.HypothesisTestCase): + @settings(max_examples=20, deadline=None) + @given( + n=st.integers(10, 10), + m=st.integers(10, 10), + preserve_sparsity=st.booleans(), + rnd_seed=st.integers(1, 5), + **hu.gcs_cpu_only + ) + def test_int8_gen_quant_params_min_max_op( + self, n, m, preserve_sparsity, rnd_seed, gc, dc + ): + X_min = 0 if preserve_sparsity else -77 + X_max = X_min + 255 + np.random.seed(rnd_seed) + X = np.round(np.random.rand(n, m) * (X_max - X_min) + X_min).astype( + np.float32 + ) + # Calculate X_qparam + hist, bin_edges = np.histogram(X.flatten(), bins=2048) + X_qparam = dnnlowp_pybind11.ChooseStaticQuantizationParams( + np.min(X), np.max(X), hist, preserve_sparsity, 8, "MIN_MAX_QUANTIZATION" + ) + + # Build a net to generate X's qparam using the Int8GenQuantParamsMinMax op + workspace.FeedBlob("X", X, device_option=gc) + workspace.FeedBlob("X_min", np.array([np.min(X)]), device_option=gc) + workspace.FeedBlob("X_max", np.array([np.max(X)]), device_option=gc) + dnnlowp_pybind11.CreateInt8QuantSchemeBlob( + "quant_scheme", "MIN_MAX_QUANTIZATION", preserve_sparsity + ) + assert workspace.HasBlob( + "quant_scheme" + ), "Failed to create the quant_scheme blob in current workspace" + + gen_quant_params_net = core.Net("gen_quant_params_min_max") + gen_quant_params_op = core.CreateOperator( + "Int8GenQuantParamsMinMax", + ["X_min", "X_max", "quant_scheme"], + ["quant_param"], + device_option=gc, + ) + gen_quant_params_net.Proto().op.extend([gen_quant_params_op]) + assert workspace.RunNetOnce( + gen_quant_params_net + ), "Failed to run the gen_quant_params net" + scale, zero_point = dnnlowp_pybind11.ObserveInt8QuantParamsBlob("quant_param") + + shapes, types = workspace.InferShapesAndTypes( + [gen_quant_params_net], + blob_dimensions={"X": [n, m], "X_min": [1], "X_max": [1], "quant_scheme": [1]}, + blob_types={"X": core.DataType.FLOAT, "X_min": core.DataType.FLOAT, "X_max": core.DataType.FLOAT, "quant_scheme": core.DataType.STRING} + ) + self.assertEqual(shapes["quant_param"], [1]) + self.assertEqual(types["quant_param"], core.DataType.FLOAT) + + np.testing.assert_equal(scale, X_qparam.scale) + np.testing.assert_equal(zero_point, X_qparam.zero_point) diff --git a/caffe2/serialize/crc_alt.h b/caffe2/serialize/crc_alt.h index be51083fec0e..e7c986ff89fb 100644 --- a/caffe2/serialize/crc_alt.h +++ b/caffe2/serialize/crc_alt.h @@ -680,12 +680,12 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) // put operator for one zero bit in odd odd[0] = Polynomial; // CRC-32 polynomial - for (int i = 1; i < CrcBits; i++) + for (uint32_t i = 1; i < CrcBits; i++) odd[i] = 1 << (i - 1); // put operator for two zero bits in even // same as gf2_matrix_square(even, odd); - for (int i = 0; i < CrcBits; i++) + for (uint32_t i = 0; i < CrcBits; i++) { uint32_t vec = odd[i]; even[i] = 0; @@ -695,7 +695,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) } // put operator for four zero bits in odd // same as gf2_matrix_square(odd, even); - for (int i = 0; i < CrcBits; i++) + for (uint32_t i = 0; i < CrcBits; i++) { uint32_t vec = even[i]; odd[i] = 0; @@ -711,7 +711,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) for (; lengthB > 0; lengthB >>= 1) { // same as gf2_matrix_square(a, b); - for (int i = 0; i < CrcBits; i++) + for (uint32_t i = 0; i < CrcBits; i++) { uint32_t vec = b[i]; a[i] = 0; diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 7928d5e3de86..3d9701274ba3 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -65,7 +65,7 @@ PyTorchStreamReader::PyTorchStreamReader(std::istream* in) } PyTorchStreamReader::PyTorchStreamReader( - std::unique_ptr in) + std::shared_ptr in) : ar_(std::make_unique()), in_(std::move(in)) { init(); } diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 2e841d0ad824..ee7e971344ea 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -156,7 +156,7 @@ class CAFFE2_API PyTorchStreamReader final { public: explicit PyTorchStreamReader(const std::string& file_name); explicit PyTorchStreamReader(std::istream* in); - explicit PyTorchStreamReader(std::unique_ptr in); + explicit PyTorchStreamReader(std::shared_ptr in); // return dataptr, size std::tuple getRecord(const std::string& name); @@ -180,7 +180,7 @@ class CAFFE2_API PyTorchStreamReader final { std::unique_ptr ar_; std::string archive_name_; std::string archive_name_plus_slash_; - std::unique_ptr in_; + std::shared_ptr in_; int64_t version_; }; diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index db02f7a8fb16..a9d2e4f50e45 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -110,6 +110,12 @@ if(INTERN_BUILD_ATEN_OPS) endif(MSVC) endif(CXX_AVX2_FOUND) + if(CXX_VSX_FOUND) + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_VSX_CPU_DEFINITION") + LIST(APPEND CPU_CAPABILITY_NAMES "VSX") + LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_VSX_FLAGS}") + endif(CXX_VSX_FOUND) + list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES) math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1") diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index c0e54450b409..968456c40490 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -46,6 +46,8 @@ endif() # 3. If MSVC_Z7_OVERRIDE is ON, then /Zi and /ZI will be replaced with /Z7 # for Debug and RelWithDebInfo builds if(MSVC) + # skip unwanted includes from windows.h + add_definitions(-DWIN32_LEAN_AND_MEAN) foreach(flag_var CMAKE_C_FLAGS CMAKE_C_FLAGS_RELEASE CMAKE_C_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL) @@ -195,6 +197,21 @@ elseif(INTERN_USE_EIGEN_BLAS) list(APPEND Caffe2_DEPENDENCY_LIBS eigen_blas) endif() +# ---[ FFTW +set(AT_FFTW_ENABLED 0) +set(USE_FFTW OFF) +if(USE_FFTW OR NOT MKL_FOUND) + find_library(LIBFFTW3 fftw3) + if(LIBFFTW3) + find_path(FFTW3_INCLUDE_DIR NAMES fftw3.h ONLY_CMAKE_FIND_ROOT_PATH) + if(FFTW3_INCLUDE_DIR) + SET(AT_FFTW_ENABLED 1) + SET(USE_FFTW ON) + include_directories(${FFTW3_INCLUDE_DIR}) + endif() + endif() +endif() + # ---[ Dependencies # NNPACK and family (QNNPACK, PYTORCH_QNNPACK, and XNNPACK) can download and # compile their dependencies in isolation as part of their build. These dependencies @@ -1498,8 +1515,6 @@ if(NOT INTERN_BUILD_MOBILE) if(MSVC) # we want to respect the standard, and we are bored of those **** . add_definitions(-D_CRT_SECURE_NO_DEPRECATE=1) - # skip unwanted includes from windows.h - add_definitions(-DWIN32_LEAN_AND_MEAN) list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/wd4819,/wd4503,/wd4190,/wd4244,/wd4251,/wd4275,/wd4522") endif() @@ -1623,6 +1638,7 @@ if(NOT INTERN_BUILD_MOBILE) add_compile_options(-DUSE_GCC_GET_CPUID) endif() + find_package(VSX) # checks VSX find_package(AVX) # checks AVX and AVX2 # we don't set -mavx and -mavx2 flags globally, but only for specific files diff --git a/cmake/External/nnpack.cmake b/cmake/External/nnpack.cmake index 84244dc864c3..b1dcd728e690 100644 --- a/cmake/External/nnpack.cmake +++ b/cmake/External/nnpack.cmake @@ -27,7 +27,7 @@ endif() # (2) Anything but x86, x86-64, ARM, ARM64 - unsupported ############################################################################## if(CMAKE_SYSTEM_PROCESSOR) - if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64|armv5te|armv7-a|armv7l|aarch64)$") + if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64|armv5te|armv7-a|armv7l|arm64|aarch64)$") message(WARNING "NNPACK is not supported on ${CMAKE_SYSTEM_PROCESSOR} processors. " "The only supported architectures are x86, x86-64, ARM, and ARM64. " "Turn this warning off by USE_NNPACK=OFF.") diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake index bd68f5f36735..acd00cfa6772 100644 --- a/cmake/Modules/FindARM.cmake +++ b/cmake/Modules/FindARM.cmake @@ -41,9 +41,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ENDIF (OMAP4_TRUE) ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin") - EXEC_PROGRAM("/usr/sbin/sysctl -n hw.optional.arm64" OUTPUT_VARIABLE - IS_ARM64) - IF(IS_ARM64 STREQUAL "1") + IF(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") set(NEON_FOUND true CACHE BOOL "NEON available on ARM64") ENDIF() EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE diff --git a/cmake/Modules/FindVSX.cmake b/cmake/Modules/FindVSX.cmake new file mode 100644 index 000000000000..74691f9240fb --- /dev/null +++ b/cmake/Modules/FindVSX.cmake @@ -0,0 +1,35 @@ + +IF(CMAKE_SYSTEM_NAME MATCHES "Linux") + message("-- ") + EXEC_PROGRAM(LD_SHOW_AUXV=1 ARGS "/bin/true" OUTPUT_VARIABLE bintrue) + if(bintrue MATCHES "AT_PLATFORM:[ \\t\\n\\r]*([a-zA-Z0-9_]+)[ \\t\\n\\r]*") + if(CMAKE_MATCH_COUNT GREATER 0) + string(TOLOWER ${CMAKE_MATCH_1} platform) + if(${platform} MATCHES "^power") + message("-- POWER Platform: ${platform}") + SET(POWER_COMP TRUE CACHE BOOL "power ") + SET(CXX_VSX_FLAGS "${CXX_VSX_FLAGS} -mcpu=${platform} -mtune=${platform}" ) + endif() + endif() + endif() + SET(VSX_CODE " #include + int main() { + float __attribute__((aligned(16))) vptr_y[8] = { 1.0f,2.f,3.f,4.f,4.f,3.f,2.f,1.f }; + __vector float v_result = vec_add(vec_vsx_ld(0, vptr_y), vec_vsx_ld(16, vptr_y)); + return 0; + }") + #check_cxx_compiler_flag(-mvsx vsx_flag) + SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + SET(CMAKE_REQUIRED_FLAGS "-mvsx") + CHECK_C_SOURCE_COMPILES("${VSX_CODE}" C_VSX_FOUND) + CHECK_CXX_SOURCE_COMPILES("${VSX_CODE}" CXX_VSX_FOUND) + SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + if(CXX_VSX_FOUND) + message("-- VSX flag was set.") + SET(CXX_VSX_FLAGS "${CXX_VSX_FLAGS} -mvsx" ) + elseif(POWER_COMP) + message(WARNING "-- VSX flag was not set.") + endif() + message("-- ") +endif() + diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 92015c269083..dd9523d1b3fb 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -117,6 +117,7 @@ function(caffe2_print_configuration_summary) endif() message(STATUS " USE_METAL : ${USE_METAL}") message(STATUS " USE_PYTORCH_METAL : ${USE_PYTORCH_METAL}") + message(STATUS " USE_FFTW : ${USE_FFTW}") message(STATUS " USE_MKL : ${CAFFE2_USE_MKL}") message(STATUS " USE_MKLDNN : ${USE_MKLDNN}") if(${CAFFE2_USE_MKLDNN}) diff --git a/docker.Makefile b/docker.Makefile index 3cd59f146e38..3af77ab9c7d1 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -1,31 +1,38 @@ -DOCKER_REGISTRY = docker.io -DOCKER_ORG = $(shell docker info 2>/dev/null | sed '/Username:/!d;s/.* //') -DOCKER_IMAGE = pytorch -DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) +DOCKER_REGISTRY = docker.io +DOCKER_ORG = $(shell docker info 2>/dev/null | sed '/Username:/!d;s/.* //') +DOCKER_IMAGE = pytorch +DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) ifeq ("$(DOCKER_ORG)","") $(warning WARNING: No docker user found using results from whoami) -DOCKER_ORG = $(shell whoami) +DOCKER_ORG = $(shell whoami) endif -CUDA_VERSION = 11.0 -CUDNN_VERSION = 8 -BASE_RUNTIME = ubuntu:18.04 -BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-devel-ubuntu18.04 +CUDA_VERSION = 11.0 +CUDNN_VERSION = 8 +BASE_RUNTIME = ubuntu:18.04 +BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-devel-ubuntu18.04 # The conda channel to use to install pytorch / torchvision -INSTALL_CHANNEL = pytorch +INSTALL_CHANNEL = pytorch -PYTHON_VERSION = 3.7 +PYTHON_VERSION = 3.7 # Can be either official / dev -BUILD_TYPE = dev -BUILD_PROGRESS = auto -BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ - --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ - --build-arg CUDA_VERSION=$(CUDA_VERSION) \ - --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) -DOCKER_BUILD = DOCKER_BUILDKIT=1 docker build --progress=$(BUILD_PROGRESS) --target $(BUILD_TYPE) -t $(DOCKER_FULL_NAME):$(DOCKER_TAG) $(BUILD_ARGS) . -DOCKER_PUSH = docker push $(DOCKER_FULL_NAME):$(DOCKER_TAG) +BUILD_TYPE = dev +BUILD_PROGRESS = auto +BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ + --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ + --build-arg CUDA_VERSION=$(CUDA_VERSION) \ + --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) +EXTRA_DOCKER_BUILD_FLAGS ?= +DOCKER_BUILD = DOCKER_BUILDKIT=1 \ + docker build \ + --progress=$(BUILD_PROGRESS) \ + $(EXTRA_DOCKER_BUILD_FLAGS) \ + --target $(BUILD_TYPE) \ + -t $(DOCKER_FULL_NAME):$(DOCKER_TAG) \ + $(BUILD_ARGS) . +DOCKER_PUSH = docker push $(DOCKER_FULL_NAME):$(DOCKER_TAG) .PHONY: all all: devel-image diff --git a/docs/source/conf.py b/docs/source/conf.py index fe1e2260be72..610f6efa0840 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -161,7 +161,7 @@ # TODO: verify this works as expected release = 'master' -# Customized html_title here. +# Customized html_title here. # Default is " ".join(project, release, "documentation") if not set if RELEASE: # remove hash (start with 'a') from version number if any @@ -192,6 +192,9 @@ # Disable docstring inheritance autodoc_inherit_docstrings = False +# Disable displaying type annotations, these can be very verbose +autodoc_typehints = 'none' + # -- katex javascript in header # @@ -253,9 +256,9 @@ def setup(app): add_css(css_file) # From PyTorch 1.5, we now use autogenerated files to document classes and -# functions. This breaks older references since +# functions. This breaks older references since # https://docs.pytorch.org/torch.html#torch.flip -# moved to +# moved to # https://docs.pytorch.org/torch/generated/torchflip.html # which breaks older links from blog posts, stack overflow answers and more. # To mitigate that, we add an id="torch.flip" in an appropriated place @@ -278,7 +281,7 @@ def visit_reference(self, node): # to autogenerated content anchor = ref_anchor[1] txt = node.parent.astext() - if txt == anchor or txt == anchor.split('.')[-1]: + if txt == anchor or txt == anchor.split('.')[-1]: self.body.append('

'.format(ref_anchor[1])) return old_call(self, node) Klass.visit_reference = visit_reference diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index f5bce396054b..b35a34fc0265 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -384,16 +384,24 @@ Collective functions .. autofunction:: broadcast +.. autofunction:: broadcast_object_list + .. autofunction:: all_reduce .. autofunction:: reduce .. autofunction:: all_gather +.. autofunction:: all_gather_object + .. autofunction:: gather +.. autofunction:: gather_object + .. autofunction:: scatter +.. autofunction:: scatter_object_list + .. autofunction:: reduce_scatter .. autofunction:: all_to_all diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index b5d78572c06b..a3bee886f062 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -13,6 +13,7 @@ Functions --------- .. autofunction:: cholesky +.. autofunction:: cond .. autofunction:: det .. autofunction:: eigh .. autofunction:: eigvalsh diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index 9dc107d86267..49bbc1df45a0 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -274,6 +274,71 @@ In addition, Dropout layer need defined in init function so that inferencing can def forward(self, x): x = self.dropout(x) +Using dictionaries to handle Named Arguments as model inputs +------------------------------------------------------------ + +There are two ways to handle models which consist of named parameters or keyword arguments as inputs: + +* The first method is to pass all the inputs in the same order as required by the model and pass None + values for the keyword arguments that do not require a value to be passed + +* The second and more intuitive method is to represent the keyword arguments as key-value pairs where + the key represents the name of the argument in the model signature and the value represents the value + of the argument to be passed + +For example, in the model: :: + + class Model(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + m = Model() + x = torch.randn(2, 3) + z = torch.randn(2, 3) + +There are two ways of exporting the model: + +* Not using a dictionary for the keyword arguments and passing all the inputs in the same order + as required by the model :: + + torch.onnx.export(model, (x, None, z), ‘test.onnx’) + +* Using a dictionary to represent the keyword arguments. This dictionary is always passed in + addition to the non-keyword arguments and is always the last argument in the args tuple. :: + + torch.onnx.export(model, (x, {'y': None, 'z': z}), ‘test.onnx’) + +For cases in which there are no keyword arguments, models can be exported with either an +empty or no dictionary. For example, :: + + torch.onnx.export(model, (x, {}), ‘test.onnx’) + or + torch.onnx.export(model, (x, ), ‘test.onnx’) + +An exception to this rule are cases in which the last input is also of a dictionary type. +In these cases it is mandatory to have an empty dictionary as the last argument in the +args tuple. For example, :: + + class Model(torch.nn.Module): + def forward(self, k, x): + ... + return x + m = Model() + k = torch.randn(2, 3)   + x = {torch.tensor(1.): torch.randn(2, 3)} + +Without the presence of the empty dictionary, the export call assumes that the +‘x’ input is intended to represent the optional dictionary consisting of named arguments. +In order to prevent this from being an issue a constraint is placed to provide an empty +dictionary as the last input in the tuple args in such cases. +The new call would look like this. :: + + torch.onnx.export(model, (k, x, {}), ‘test.onnx’) + + Indexing -------- diff --git a/ios/LibTorch.podspec b/ios/LibTorch.podspec index 236f1de7988f..b90cf6aff5d6 100644 --- a/ios/LibTorch.podspec +++ b/ios/LibTorch.podspec @@ -1,6 +1,6 @@ Pod::Spec.new do |s| s.name = 'LibTorch' - s.version = '1.7.0' + s.version = '1.7.1' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/pytorch' diff --git a/modules/detectron/group_spatial_softmax_op.cu b/modules/detectron/group_spatial_softmax_op.cu index 92e89ae5acc2..a37a3fba55a7 100644 --- a/modules/detectron/group_spatial_softmax_op.cu +++ b/modules/detectron/group_spatial_softmax_op.cu @@ -112,6 +112,7 @@ bool GroupSpatialSoftmaxOp::RunOnDevice() { GroupSpatialSoftmaxKernel<<>>( N, A, W, H, Xdata, Pdata, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -158,11 +159,13 @@ bool GroupSpatialSoftmaxGradientOp::RunOnDevice() { SumProbsKernel<<>>( N, A, W, H, Ydata, dYdata, sum_probs_data, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Step 2: dX[i] = dX[i] - s SubSumKernel<<>>( N, A, W, H, sum_probs_.data(), dXdata, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Step 3: dX[i] = Y[i] * dX[i] math::Mul(Y.size(), dXdata, Ydata, dXdata, &context_); diff --git a/modules/detectron/ps_roi_pool_op.cu b/modules/detectron/ps_roi_pool_op.cu index 1ba418be5c99..68e4ec377d62 100644 --- a/modules/detectron/ps_roi_pool_op.cu +++ b/modules/detectron/ps_roi_pool_op.cu @@ -253,6 +253,7 @@ bool PSRoIPoolOp::RunOnDevice() { output_size, X.data(), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, R.data(), output_dim_, group_size_, Y->mutable_data(), A->mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -276,6 +277,7 @@ bool PSRoIPoolGradientOp::RunOnDevice() { dY.size(), dY.data(), A.data(), R.dim32(0), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, output_dim_, dX->mutable_data(), R.data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/modules/detectron/roi_pool_f_op.cu b/modules/detectron/roi_pool_f_op.cu index 62948f7eacbe..b261911b95a1 100644 --- a/modules/detectron/roi_pool_f_op.cu +++ b/modules/detectron/roi_pool_f_op.cu @@ -149,6 +149,7 @@ bool RoIPoolFOp::RunOnDevice() { output_size, X.data(), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, R.data(), Y->mutable_data(), A->mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -173,6 +174,7 @@ bool RoIPoolFGradientOp::RunOnDevice() { dY.size(), dY.data(), A.data(), R.dim32(0), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, dX->mutable_data(), R.data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return true; } diff --git a/modules/detectron/select_smooth_l1_loss_op.cu b/modules/detectron/select_smooth_l1_loss_op.cu index 9065bfc7afbe..ce68fcff634d 100644 --- a/modules/detectron/select_smooth_l1_loss_op.cu +++ b/modules/detectron/select_smooth_l1_loss_op.cu @@ -129,6 +129,7 @@ bool SelectSmoothL1LossOp::RunOnDevice() { M, Y_hat.data(), Y.data(), L.data(), buff_.mutable_data(), S.data(), beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Sum of all losses // al := sum_i l_i @@ -175,6 +176,7 @@ bool SelectSmoothL1LossGradientOp::RunOnDevice() { D, H, W, M, Y_hat.data(), Y.data(), L.data(), d_Y_hat->mutable_data(), d_avg_loss.data(), scale_, S.data(), beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/modules/detectron/sigmoid_cross_entropy_loss_op.cu b/modules/detectron/sigmoid_cross_entropy_loss_op.cu index d69a7b41dc33..bb86560fcb01 100644 --- a/modules/detectron/sigmoid_cross_entropy_loss_op.cu +++ b/modules/detectron/sigmoid_cross_entropy_loss_op.cu @@ -93,6 +93,8 @@ bool SigmoidCrossEntropyLossOp::RunOnDevice() { T.data(), losses_.mutable_data(), counts_.mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + float* avg_loss_data = avg_loss->mutable_data(); math::Sum( losses_.size(), losses_.data(), avg_loss_data, &context_); @@ -106,6 +108,7 @@ bool SigmoidCrossEntropyLossOp::RunOnDevice() { CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(normalizer_.size(), normalizer_data, 1e-5); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Div( 1, avg_loss_data, normalizer_data, avg_loss_data, &context_); } @@ -135,6 +138,7 @@ bool SigmoidCrossEntropyLossGradientOp::RunOnDevice() { T.data(), dX->mutable_data(), counts_.mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); if (normalize_) { float* normalizer_data = normalizer_.mutable_data(); math::Sum( @@ -145,6 +149,7 @@ bool SigmoidCrossEntropyLossGradientOp::RunOnDevice() { CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(normalizer_.size(), normalizer_data, 1e-5); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Div( 1, d_avg_loss.data(), diff --git a/modules/detectron/sigmoid_focal_loss_op.cu b/modules/detectron/sigmoid_focal_loss_op.cu index 5b130c8dfc1f..e6f2dea21b5d 100644 --- a/modules/detectron/sigmoid_focal_loss_op.cu +++ b/modules/detectron/sigmoid_focal_loss_op.cu @@ -134,6 +134,7 @@ bool SigmoidFocalLossOp::RunOnDevice() { N, D, H, W, X.data(), T.data(), wp.data(), gamma_, alpha_, num_classes_, losses_.mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Sum( losses_.size(), losses_.data(), avg_loss_data, &context_); @@ -165,6 +166,7 @@ bool SigmoidFocalLossGradientOp::RunOnDevice() { N, D, H, W, X.data(), T.data(), dX->mutable_data(), wp.data(), gamma_, alpha_, num_classes_, d_avg_loss.data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Scale( dX->size(), scale_, diff --git a/modules/detectron/smooth_l1_loss_op.cu b/modules/detectron/smooth_l1_loss_op.cu index 1a3e8b78b53f..ea835a4bc2b9 100644 --- a/modules/detectron/smooth_l1_loss_op.cu +++ b/modules/detectron/smooth_l1_loss_op.cu @@ -102,6 +102,7 @@ bool SmoothL1LossOp::RunOnDevice() { context_.cuda_stream()>>>( buff_.size(), buff_.data(), buff_.mutable_data(), beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Element-wise weighted smooth l1 loss (can be used to specify a per-element // loss weight) @@ -164,6 +165,8 @@ bool SmoothL1LossGradientOp::RunOnDevice() { context_.cuda_stream()>>>( buff_.size(), buff_.data(), d_Y_hat->mutable_data(), d_avg_loss.data(), scale_ / N, beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Element-wise scale by alpha_in and alpha_out math::Mul( d_Y_hat->size(), d_Y_hat->data(), alpha_in.data(), diff --git a/modules/detectron/softmax_focal_loss_op.cu b/modules/detectron/softmax_focal_loss_op.cu index 93635269f176..b7f8d2423ebc 100644 --- a/modules/detectron/softmax_focal_loss_op.cu +++ b/modules/detectron/softmax_focal_loss_op.cu @@ -176,6 +176,7 @@ bool SoftmaxFocalLossOp::RunOnDevice() { <<>>( N, A, H, W, Xdata, P->mutable_data(), num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Compute loss for each x,y location const int* Tdata = T.data(); @@ -184,6 +185,7 @@ bool SoftmaxFocalLossOp::RunOnDevice() { 0, context_.cuda_stream()>>>( N, A, H, W, P->data(), Tdata, losses_.mutable_data(), Wdata, gamma_, alpha_, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // sum the losses float* avg_loss_data = avg_loss->mutable_data(); @@ -227,6 +229,8 @@ bool SoftmaxFocalLossGradientOp::RunOnDevice() { 0, context_.cuda_stream()>>>( N, A, H, W, Pdata, Tdata, buff_.mutable_data(), Wdata, gamma_, alpha_, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Compute the gradient with the weights const float* Bdata = buff_.data(); SoftmaxFocalLossGradientKernel @@ -234,6 +238,7 @@ bool SoftmaxFocalLossGradientOp::RunOnDevice() { 0, context_.cuda_stream()>>>( N, D, H, W, Pdata, Tdata, Bdata, d_avg_loss.data(), dX->mutable_data(), num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Scale( dX->size(), scale_, diff --git a/modules/detectron/spatial_narrow_as_op.cu b/modules/detectron/spatial_narrow_as_op.cu index 97ddc492eb07..ff8b5632e80a 100644 --- a/modules/detectron/spatial_narrow_as_op.cu +++ b/modules/detectron/spatial_narrow_as_op.cu @@ -115,6 +115,7 @@ bool SpatialNarrowAsOp::DoRunWithType() { out_width, A.template data(), C->template mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -152,6 +153,7 @@ bool SpatialNarrowAsGradientOp::DoRunWithType() { out_width, dC.template data(), dA->template mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/modules/detectron/upsample_nearest_op.cu b/modules/detectron/upsample_nearest_op.cu index 38af4254f922..0ea32e348c0b 100644 --- a/modules/detectron/upsample_nearest_op.cu +++ b/modules/detectron/upsample_nearest_op.cu @@ -164,6 +164,8 @@ bool UpsampleNearestOp::RunOnDevice() { upscale<<>>( input_data, output_data, no_elements, scale_, d1, d2, d3); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return true; } @@ -209,6 +211,7 @@ bool UpsampleNearestGradientOp::RunOnDevice() { math::Set(no_elements, 0.f, gradInput_data, &context_); downscale<<>>( gradInput_data, gradOutput_data, no_elements, scale_, d1, d2, d3); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/mypy.ini b/mypy.ini index f4b37f15a820..d5b1ed20e081 100644 --- a/mypy.ini +++ b/mypy.ini @@ -101,15 +101,15 @@ ignore_errors = True [mypy-torch.nn.quantized.modules.conv] ignore_errors = True -[mypy-torch._lobpcg] -ignore_errors = True - [mypy-torch._appdirs] ignore_errors = True [mypy-torch._utils] ignore_errors = True +[mypy-torch._overrides] +ignore_errors = True + [mypy-torch.utils.tensorboard._caffe2_graph] ignore_errors = True @@ -131,42 +131,9 @@ ignore_errors = True [mypy-torch.nn.quantized.modules.batchnorm] ignore_errors = True -[mypy-torch.nn.intrinsic.quantized.modules.conv_relu] -ignore_errors = True - -[mypy-torch.nn.intrinsic.quantized.modules.bn_relu] -ignore_errors = True - -[mypy-torch.nn.intrinsic.quantized.modules.linear_relu] -ignore_errors = True - [mypy-torch.nn.intrinsic.qat.modules.conv_fused] ignore_errors = True -[mypy-torch.onnx.operators] -ignore_errors = True - -[mypy-torch.onnx.symbolic_opset8] -ignore_errors = True - -[mypy-torch.onnx.symbolic_opset9] -ignore_errors = True - -[mypy-torch.onnx.symbolic_opset11] -ignore_errors = True - -[mypy-torch.onnx.symbolic_caffe2] -ignore_errors = True - -[mypy-torch.onnx.symbolic_helper] -ignore_errors = True - -[mypy-torch.onnx.symbolic_registry] -ignore_errors = True - -[mypy-torch.onnx.utils] -ignore_errors = True - [mypy-torch.multiprocessing.pool] ignore_errors = True diff --git a/requirements-flake8.txt b/requirements-flake8.txt new file mode 100644 index 000000000000..1e2ba252556f --- /dev/null +++ b/requirements-flake8.txt @@ -0,0 +1,8 @@ +flake8==3.8.2 +flake8-bugbear==20.1.4 +flake8-comprehensions==3.3.0 +flake8-executable==2.0.4 +flake8-pyi==20.5.0 +mccabe +pycodestyle==2.6.0 +pyflakes==2.2.0 diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 4b143fc19827..deb7a161e1d3 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -187,6 +187,8 @@ ("aten::ifft", datetime.date(2021, 1, 31)), ("aten::irfft", datetime.date(2021, 1, 31)), ("aten::rfft", datetime.date(2021, 1, 31)), + ("aten::quantile", datetime.date(2021, 1, 31)), + ("aten::nanquantile", datetime.date(2021, 1, 31)), ("aten::_fft_with_size", datetime.date(2021, 1, 31)), ] diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index ca4ba0fdb3da..10f36cc8e394 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -739,8 +739,8 @@ void checkScopeCallbacks() { std::string(fn.name().str()) == "test_user_scope") { found_user_scope = true; } - }, - [](const at::RecordFunction&) {})); + return nullptr; + })); bool bad_scope = false; auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) { @@ -752,9 +752,8 @@ void checkScopeCallbacks() { } else { bad_scope = true; } - return true; - }, - [](const at::RecordFunction&) {}) + return nullptr; + }) .scopes({scope})); }; @@ -813,8 +812,8 @@ TEST(RecordFunctionTest, Basic) { } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { ts_names.insert(fn.name().str()); } - }, - [](const RecordFunction&) {}) + return nullptr; + }) .needsInputs(true)); TracedTestInputs eager_inputs, jit_inputs; @@ -851,9 +850,8 @@ TEST(RecordFunctionTest, Basic) { if (std::string(fn.name().str()) == "test") { ++sampled_cb_ctr; } - return true; - }, - [](const RecordFunction&) {}) + return nullptr; + }) .samplingProb(sampling_prob)); }; @@ -863,9 +861,8 @@ TEST(RecordFunctionTest, Basic) { if (std::string(fn.name().str()) == "test") { ++non_sampled_cb_ctr; } - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); auto handle = setup_sampled_callback(0.5); @@ -908,9 +905,8 @@ TEST(RecordFunctionTest, Basic) { [&fn_names, &mtx](const RecordFunction& fn) { std::lock_guard lock(mtx); fn_names.push_back(fn.name().str()); - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); { RecordFunctionGuard g1(false); { @@ -934,8 +930,10 @@ TEST(RecordFunctionTest, Basic) { std::vector ids; auto add_remove_test_add_cb = [&ids](size_t id) { return addGlobalCallback(RecordFunctionCallback( - [&ids, id](const RecordFunction& fn) { ids.push_back(id); }, - [](const RecordFunction&) {})); + [&ids, id](const RecordFunction& fn) { + ids.push_back(id); + return nullptr ; + })); }; auto h1 = add_remove_test_add_cb(1); @@ -972,8 +970,7 @@ TEST(RecordFunctionTest, Basic) { ids.clear(); addGlobalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(1); }, - [](const RecordFunction&) {})); + [&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; })); { RECORD_USER_SCOPE("test"); } @@ -983,8 +980,7 @@ TEST(RecordFunctionTest, Basic) { auto th = std::thread([&ids]() { addThreadLocalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(2); }, - [](const RecordFunction&) {})); + [&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; })); { RECORD_USER_SCOPE("test_thread"); } }); @@ -1070,8 +1066,7 @@ TEST(RecordFunctionTest, Basic) { bool ran = false; should_run = false; addGlobalCallback(RecordFunctionCallback( - [&ran](const RecordFunction& fn) { ran = true; }, - [](const RecordFunction&) {}) + [&ran](const RecordFunction& fn) { ran = true; return nullptr; }) .setShouldRun(shouldRunCallback)); { RECORD_USER_SCOPE("test"); } @@ -1093,8 +1088,8 @@ TEST(RecordFunctionTest, Basic) { auto handle = addThreadLocalCallback(RecordFunctionCallback( [&recorded_op](const RecordFunction& fn) { recorded_op = fn.name().str(); - }, - [](const RecordFunction&) {})); + return nullptr; + })); ThreadLocalState state; std::thread t_child([state]() { ThreadLocalStateGuard g_tls(state); @@ -1111,16 +1106,20 @@ TEST(RecordFunctionTest, Basic) { bool has_ids = false; addGlobalCallback( RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; }, - [](const RecordFunction&) {}) + [&has_ids](const RecordFunction& fn) { + has_ids = fn.handle() > 0; + return nullptr; + }) .needsIds(true)); { RECORD_USER_SCOPE("test"); } TORCH_CHECK(has_ids); clearCallbacks(); has_ids = false; addGlobalCallback(RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; }, - [](const RecordFunction&) {})); + [&has_ids](const RecordFunction& fn) { + has_ids = fn.handle() > 0; + return nullptr; + })); { RECORD_USER_SCOPE("test"); } TORCH_CHECK(!has_ids); clearCallbacks(); @@ -1138,6 +1137,7 @@ TEST(RecordFunctionTest, OperatorNameOverload) { } else { operator_names.insert("No Operator Name"); } + return nullptr; }) .scopes({at::RecordScope::FUNCTION})); auto t = torch::randn({1, 2, 3}, at::kCPU); @@ -1209,9 +1209,8 @@ TEST(ThreadLocalDebugInfoTest, Basic) { [&done](const RecordFunction&) { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); { c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); auto t = torch::randn({1, 2, 3}, at::kCPU); diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp index 910331166d51..c77d89af5afa 100644 --- a/test/cpp/jit/test_module_api.cpp +++ b/test/cpp/jit/test_module_api.cpp @@ -43,6 +43,44 @@ static void import_libs( si.loadType(QualifiedName(class_name)); } +TEST(ModuleAPITest, MethodRunAsync) { + // Module m("m"); + // m.define(R"( + // def forward(self): + // r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + // r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + // return r1.wait() + r2.wait() + // )"); + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + // borrow model file from TEST(GraphExecutorTest, runAsync_executor) + testModelFile.append("test_interpreter_async.pt"); + auto m = load(testModelFile); + + auto counter = 0; + std::mutex mtx; + + auto launcher = [&](std::function f) { + mtx.lock(); + ++counter; + mtx.unlock(); + at::launch(move(f)); + }; + + auto method = m.get_method("forward"); + + std::vector stack; + auto kwargs = std::unordered_map(); + auto future = method.run_async(stack, kwargs, launcher); + + future->wait(); + + // expect 2 forks and 2 wait callbacks being excuted on provided taskLauncher + // but ivalue::Future would be marked completed and release wait before + // finishing all callbacks + ASSERT_GE(counter, 2); +} + TEST(ModuleAPITest, Clone) { auto cu = std::make_shared(); // creating child module diff --git a/test/cpp/tensorexpr/__init__.py b/test/cpp/tensorexpr/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index debee0596489..cf658ad488f6 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -138,7 +138,9 @@ TEST(Kernel, _3) { } } -TEST(Kernel, _4) { +TEST(Kernel, DISABLED_Shape_Inference) { + // disabled: doesn't do stride propagation, and isn't being used currently + // Test TensorExpr shape inference capabilities: it should only require shapes // for the inputs { @@ -396,7 +398,7 @@ TEST(Kernel, CatInputTypesPromotion) { %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): %dim : int = prim::Constant[value=1]() %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] + %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) return (%r))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); @@ -465,7 +467,12 @@ at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { } // namespace -TEST(Kernel, SumAllAxes) { +TEST(Kernel, DISABLED_SumAllAxes) { + // [zero-dim tensors] + // NNC does not yet handle zero-dim tensors. aten::sum with no axis + // input returns a zero-dim tensors, so these tests must be disabled + // until we add support for zero-dim tensors. + // Test lowering of sum on all axes. const auto graph_template = R"IR( graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): @@ -512,6 +519,19 @@ TEST(Kernel, SumAllAxes) { } } +std::string li_to_str(at::ArrayRef li) { + std::stringstream out; + bool first = true; + for (auto elem : li) { + if (!first) { + out << ", "; + } + out << elem; + first = false; + } + return out.str(); +} + TEST(Kernel, SumOneAxis) { // Test lowering of sum on one axis. const auto graph_template = R"IR( @@ -519,7 +539,7 @@ TEST(Kernel, SumOneAxis) { %1 : int[] = prim::Constant[value=[${dim}]]() %2 : bool = prim::Constant[value=${keepdim}]() %3 : ${dtype} - %4 : Tensor = aten::sum(%0, %1, %2, %3) + %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3) return (%4))IR"; auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); @@ -531,17 +551,23 @@ TEST(Kernel, SumOneAxis) { env.d("dim", dim); env.d("keepdim", keepdim); env.s("dtype", dtypeConstant(scalar_type)); - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto o = at::empty({}, TensorOptions(kCPU)); c10::optional dtype; if (scalar_type != ScalarType::None) { dtype = static_cast(scalar_type); } auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype); + if (scalar_type == ScalarType::None) { + env.s("out_dtype", "Float"); + } else { + env.s("out_dtype", "Double"); + } + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto o = at::empty({}, TensorOptions(kCPU)); TensorExprKernel k(graph); std::vector inputs = {a}; Stmt* s = k.getCodeGenStmt(); @@ -578,7 +604,7 @@ TEST(Kernel, SumMultipleAxes) { %3 : int[] = prim::ListConstruct(%1, %2) %4 : bool = prim::Constant[value=${keepdim}]() %5 : ${dtype} - %6 : Tensor = aten::sum(%0, %3, %4, %5) + %6 : Float(${size}, strides=[${strides}]) = aten::sum(%0, %3, %4, %5) return (%6))IR"; auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); @@ -593,13 +619,17 @@ TEST(Kernel, SumMultipleAxes) { env.d("dim2", dim2); env.d("keepdim", keepdim); env.s("dtype", dtypeConstant(ScalarType::None)); + auto o = at::empty({}, TensorOptions(kCPU)); + auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); + + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); auto graph = std::make_shared(); parseIR(graph_string, &*graph); - auto o = at::empty({}, TensorOptions(kCPU)); - auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); TensorExprKernel k(graph); std::vector inputs = {a}; Stmt* s = k.getCodeGenStmt(); @@ -636,7 +666,7 @@ TEST(Kernel, Softmax2D) { graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): %1 : int = prim::Constant[value=${dim}]() %2 : int = prim::Constant[value=7]() - %3 : Tensor = aten::${op}(%0, %1, %2) + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) return (%3))IR"; auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); @@ -657,11 +687,15 @@ TEST(Kernel, Softmax2D) { for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { auto softmax_dim_size = a.sizes()[softmax_dim]; auto other_dim = (softmax_dim + 1) % a.dim(); - + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); KernelScope kernel_scope; TemplateEnv env; env.d("dim", softmax_dim); env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); auto graph = std::make_shared(); @@ -685,8 +719,6 @@ TEST(Kernel, Softmax2D) { std::vector stack = fmap(inputs); k.run(stack); auto output = stack[0].toTensor(); - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); ASSERT_EQ(output.sizes(), ref.sizes()); ASSERT_TRUE(at::allclose(output, ref)); } @@ -698,7 +730,7 @@ TEST(Kernel, Softmax3D) { graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): %1 : int = prim::Constant[value=${dim}]() %2 : int = prim::Constant[value=7]() - %3 : Tensor = aten::${op}(%0, %1, %2) + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) return (%3))IR"; auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); @@ -727,11 +759,16 @@ TEST(Kernel, Softmax3D) { other_dims.push_back(i); } } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); KernelScope kernel_scope; TemplateEnv env; env.d("dim", softmax_dim); env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); auto graph = std::make_shared(); @@ -758,8 +795,6 @@ TEST(Kernel, Softmax3D) { k.run(stack); auto output = stack[0].toTensor(); - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); ASSERT_EQ(output.sizes(), ref.sizes()); ASSERT_TRUE(at::allclose(output, ref)); } @@ -771,7 +806,7 @@ TEST(Kernel, Softmax4D) { graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): %1 : int = prim::Constant[value=${dim}]() %2 : int = prim::Constant[value=7]() - %3 : Tensor = aten::${op}(%0, %1, %2) + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) return (%3))IR"; auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); @@ -803,11 +838,16 @@ TEST(Kernel, Softmax4D) { other_dims.push_back(i); } } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); KernelScope kernel_scope; TemplateEnv env; env.d("dim", softmax_dim); env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); auto graph = std::make_shared(); @@ -835,15 +875,14 @@ TEST(Kernel, Softmax4D) { std::vector stack = fmap(inputs); k.run(stack); auto output = stack[0].toTensor(); - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); ASSERT_EQ(output.sizes(), ref.sizes()); ASSERT_TRUE(at::allclose(output, ref)); } } } -TEST(Kernel, InlineProducerIntoReduction) { +TEST(Kernel, DISABLED_InlineProducerIntoReduction) { + // see : [zero-dim tensors] KernelScope kernel_scope; // Inline producer (mul) into reduction (sum). @@ -882,7 +921,9 @@ TEST(Kernel, InlineProducerIntoReduction) { ASSERT_TRUE(at::allclose(o, ref)); } -TEST(Kernel, InlineReductionIntoConsumer) { +TEST(Kernel, DISABLED_InlineReductionIntoConsumer) { + // see : [zero-dim tensors] + KernelScope kernel_scope; // Inline producer (mul %2) into reduction (sum %4) but DO NOT diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 953c184de1fc..c1d3392fff32 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -160,6 +160,63 @@ TEST(LLVM, ByteToDoubleCastTest) { ASSERT_EQ(cg.value(), 2); } +TEST(LLVM, BitCast) { + constexpr int16_t ref16 = 1337; + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + at::Half reff16 = 1337.0f; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + + // this is broken + /*{ + KernelScope kernel_scope; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(k); + auto b = BitCast::make(kShort, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + KernelScope kernel_scope; + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + KernelScope kernel_scope; + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + KernelScope kernel_scope; + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + KernelScope kernel_scope; + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } +} + TEST(LLVM, LetTest01) { KernelScope kernel_scope; @@ -514,6 +571,32 @@ TEST(LLVM, VectorizerLoadStoreTest) { assertAllEqual(c_vec, 21); } +TEST(LLVM, VectorizeBitCast) { + KernelScope kernel_scope; + Placeholder a(BufHandle("A", {128}, kInt)); + + Tensor* c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) { + return bitcast(a.load(i)); + }); + + Placeholder c_buf(BufHandle(c->buf())); + LoopNest l({c}); + Stmt* s = l.root_stmt(); + l.vectorize(dynamic_cast(s)->front()); + ASSERT_TRUE(dynamic_cast(dynamic_cast(s)->front()) == nullptr); + + LLVMCodeGen cg(s, {a, c_buf}); + + std::vector a_vec(128); + std::vector c_vec(128); + for (auto i = 0; i < 128; ++i) { + a_vec[i] = raw_bitcast(1337.f); + } + std::vector args({a_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 1337.f); +} + TEST(LLVM, MemcpyTest) { KernelScope kernel_scope; constexpr int N = 32; diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 0a8037f28db0..aa44da858abf 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -1484,6 +1484,49 @@ TEST(LoopNest, ScheduleInlineThreeMixedSplit) { ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); } +// Check that inlining works for output tensors too +TEST(LoopNest, ScheduleInlineOutputTensors) { + KernelScope kernel_scope; + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor* x = Compute( + "x", + {{M, "m1"}, {N, "n1"}, {K, "k1"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return m * n * k; + }); + Tensor* y = Compute( + "y", + {{M, "m2"}, {N, "n2"}, {K, "k2"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x->call(m, n, k) + m; + }); + + LoopNest l1({x, y}); + l1.computeInline(x->buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); + std::ostringstream oss; + oss << *stmt1; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m1 = 0; m1 < 4; m1++) +# CHECK: for (int n1 = 0; n1 < 5; n1++) +# CHECK: for (int k1 = 0; k1 < 6; k1++) +# CHECK: x[m1, n1, k1] = (n1 * m1) * k1; +# CHECK: for (int m2 = 0; m2 < 4; m2++) +# CHECK: for (int n2 = 0; n2 < 5; n2++) +# CHECK: for (int k2 = 0; k2 < 6; k2++) +# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + TEST(LoopNest, ScheduleFuserStyle) { KernelScope kernel_scope; const int kVectorSize = 8; @@ -3498,5 +3541,199 @@ TEST(LoopNest, CacheWritesSimple) { assertAllEqual(c_data, c_ref); } +TEST(LoopNest, DeadStoreElimination) { + KernelScope kernel_scope; + VarHandle y("y", kInt); + VarHandle x("x_tail", kInt); + BufHandle f("f", {26, 5}, kFloat); + BufHandle g("g", {26, 5}, kFloat); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + For* stmt1 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(f, {x_2, y}, (x_2 + y), 1), + Store::make(g, {x_2, y}, (x_2 * y), 1), + }))); + Stmt* stmt = Block::make({stmt1}); + + // Will eliminate if not used by an output. + LoopNest loop(stmt, {f.node()}, {}, {}); + loop.eliminateDeadStores(); + + std::ostringstream oss; + oss << *loop.root_stmt(); + + const std::string& expected_ir = + R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK-NOT: g[x_tail + 5 * 4, y] + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // But won't eliminate if used by different outputs. + LoopNest loop2(stmt, {f.node(), g.node()}, {}, {}); + loop2.eliminateDeadStores(); + + oss.clear(); + oss << *loop2.root_stmt(); + + const std::string& expected_ir2 = + R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK: g[x_tail + 5 * 4, y] + )IR"; + torch::jit::testing::FileCheck().run(expected_ir2, oss.str()); +} + +TEST(LoopNest, DeadStoreEliminationWithIntermediates) { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + BufHandle f("f", {26 * 5}, kFloat); + BufHandle g("g", {26 * 5}, kFloat); + BufHandle h("h", {26, 5}, kFloat); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + For* stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); + For* stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); + For* stmt3 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(h, {x, y}, Load::make(f, {x * y}, 1), 1), + }))); + Stmt* stmt = Block::make({stmt1, stmt2, stmt3}); + + // Will eliminate the write to g, but not f since it used by the producer of + // h. + LoopNest loop(stmt, {h.node()}, {}, {}); + loop.eliminateDeadStores(); + + std::ostringstream oss; + oss << *loop.root_stmt(); + + const std::string& expected_ir = + R"IR( + #CHECK: f[x] = x; + #CHECK-NOT: g[z] = + #CHECK: h[x, y] = f[x * y]; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Sanity check won't eliminate if g is an output. + LoopNest loop2(stmt, {h.node(), g.node()}, {}, {}); + loop2.eliminateDeadStores(); + + oss.clear(); + oss << *loop2.root_stmt(); + + const std::string& expected_ir2 = + R"IR( + #CHECK: f[x] = x; + #CHECK: g[z] = z + 1; + #CHECK: h[x, y] = f[x * y]; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir2, oss.str()); +} + +TEST(LoopNest, CompoundTensorSimple) { + KernelScope kernel_scope; + + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j, 1)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}, 1) + x + y, 1)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + Block* body = Block::make({outer_for1, outer_for2}); + + Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body); + + LoopNest l({A}); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + + Stmt* s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg(s, {A}); + + std::vector a_ref(50, 0); + + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 5; ++j) { + a_ref[i * 5 + j] = (i * j) + i + j; + } + } + cg.call({a_data}); + + assertAllEqual(a_data, a_ref); +} + +TEST(LoopNest, CompoundTensorUsed) { + KernelScope kernel_scope; + + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j, 1)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}, 1) + x + y, 1)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + Block* body = Block::make({outer_for1, outer_for2}); + + Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body); + Tensor* B = Compute( + "B", {{10, "i"}, {3, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i, j + 1) + A->call(i, j + 2); + }); + + LoopNest l({B}); + ASSERT_FALSE(l.computeInline(A->buf())); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + std::vector b_data(50, 0); + + Stmt* s = IRSimplifier::simplify(l.root_stmt()); + std::cout << *s << "\n "; + SimpleIREvaluator cg(s, {B}); + + std::vector b_ref(50, 0); + + auto AT = [](int i, int j) { return i * j + i + j; }; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 3; ++j) { + b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); + } + } + cg.call({b_data}); + + assertAllEqual(b_data, b_ref); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp index 0c771733d935..71ad0f5149ac 100644 --- a/test/cpp/tensorexpr/test_type.cpp +++ b/test/cpp/tensorexpr/test_type.cpp @@ -1,5 +1,6 @@ #include +#include "torch/csrc/jit/tensorexpr/eval.h" #include "torch/csrc/jit/tensorexpr/ir.h" #include "torch/csrc/jit/tensorexpr/tensor.h" @@ -42,6 +43,115 @@ TEST(Type, Test01) { } } +TEST(Type, BitCasting) { + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kInt); + } + { + KernelScope kernel_scope; + VarHandle x("x", kInt); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kFloat); + } + { + KernelScope kernel_scope; + VarHandle x("x", kShort); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kHalf); + } + { + KernelScope kernel_scope; + VarHandle x("x", kHalf); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kShort); + } + + constexpr int16_t ref16 = 1337; + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + at::Half reff16 = 1337.0f; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + using SimpleIRExprEval = ExprEval; + // this is broken + /*{ + KernelScope kernel_scope; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(*k); + auto b = BitCast::make(kShort, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + KernelScope kernel_scope; + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + KernelScope kernel_scope; + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + KernelScope kernel_scope; + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + KernelScope kernel_scope; + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } + + // This segfaults :( + /*{ + KernelScope kernel_scope; + VarHandle x("x", kDouble); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kLong); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kShort); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kInt); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + }*/ +} + TEST(Type, Propagation) { // Same types: { diff --git a/test/cpp_extensions/cuda_extension.cu b/test/cpp_extensions/cuda_extension.cu index 29511af8a0ed..0c23d89df889 100644 --- a/test/cpp_extensions/cuda_extension.cu +++ b/test/cpp_extensions/cuda_extension.cu @@ -6,6 +6,7 @@ #include #include +#include #include @@ -26,4 +27,5 @@ void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) { const int threads = 1024; const int blocks = (size + threads - 1) / threads; sigmoid_add_kernel<<>>(x, y, output, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/test/cpp_extensions/cuda_extension_kernel.cu b/test/cpp_extensions/cuda_extension_kernel.cu index 660219989863..4a942b0a20af 100644 --- a/test/cpp_extensions/cuda_extension_kernel.cu +++ b/test/cpp_extensions/cuda_extension_kernel.cu @@ -1,5 +1,6 @@ #include #include +#include #include @@ -20,4 +21,5 @@ void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) { const int threads = 1024; const int blocks = (size + threads - 1) / threads; sigmoid_add_kernel<<>>(x, y, output, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/test/cpp_extensions/cuda_extension_kernel2.cu b/test/cpp_extensions/cuda_extension_kernel2.cu index 817bdf64ac8e..ddb240e5d067 100644 --- a/test/cpp_extensions/cuda_extension_kernel2.cu +++ b/test/cpp_extensions/cuda_extension_kernel2.cu @@ -1,5 +1,6 @@ #include #include +#include #include @@ -20,4 +21,5 @@ void tanh_add_cuda(const float* x, const float* y, float* output, int size) { const int threads = 1024; const int blocks = (size + threads - 1) / threads; tanh_add_kernel<<>>(x, y, output, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/test/distributed/_pipeline/sync/conftest.py b/test/distributed/_pipeline/sync/conftest.py index 315431d0b644..561c41d11350 100644 --- a/test/distributed/_pipeline/sync/conftest.py +++ b/test/distributed/_pipeline/sync/conftest.py @@ -5,7 +5,9 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import pytest +import tempfile import torch +from torch.distributed import rpc @pytest.fixture(autouse=True) @@ -35,3 +37,17 @@ def cuda_sleep(seconds): def pytest_report_header(): return f"torch: {torch.__version__}" + +@pytest.fixture +def setup_rpc(scope="session"): + file = tempfile.NamedTemporaryFile() + rpc.init_rpc( + name="worker0", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(file.name), + ) + ) + yield + rpc.shutdown() diff --git a/test/distributed/_pipeline/sync/skip/test_gpipe.py b/test/distributed/_pipeline/sync/skip/test_gpipe.py index 96ecd84e0d18..90ecd7613d67 100644 --- a/test/distributed/_pipeline/sync/skip/test_gpipe.py +++ b/test/distributed/_pipeline/sync/skip/test_gpipe.py @@ -17,7 +17,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_1to3(balance, checkpoint): +def test_1to3(balance, checkpoint, setup_rpc): if torch.cuda.device_count() < len(balance): pytest.skip("at least %d cuda devices required" % len(balance)) @@ -61,14 +61,14 @@ def forward(self, input): input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) output = model(input) - loss = output.mean() + loss = output.local_value().mean() loss.backward() - assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) + assert torch.allclose(output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device)) -def test_none_skip(): +def test_none_skip(setup_rpc): @skippable(stash=["none"]) class Stash(nn.Module): def forward(self, input): @@ -102,7 +102,7 @@ def assert_grad_fn_is_not_portal(grad_fn, visited=None): for next_grad_fn, _ in grad_fn.next_functions: assert_grad_fn_is_not_portal(next_grad_fn, visited) - assert_grad_fn_is_not_portal(output.grad_fn) + assert_grad_fn_is_not_portal(output.local_value().grad_fn) - output.sum().backward() + output.local_value().sum().backward() assert input.grad.mean().item() == 1 diff --git a/test/distributed/_pipeline/sync/skip/test_leak.py b/test/distributed/_pipeline/sync/skip/test_leak.py index 31c4ea13b9f1..7d03a4e9db49 100644 --- a/test/distributed/_pipeline/sync/skip/test_leak.py +++ b/test/distributed/_pipeline/sync/skip/test_leak.py @@ -29,7 +29,7 @@ def forward(self, input): @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -def test_delete_portal_tensor(train, checkpoint): +def test_delete_portal_tensor(train, checkpoint, setup_rpc): # Without checkpointing: # +- Stash --+ +--- Pop ----+ - - - layers # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function @@ -97,7 +97,7 @@ def forward(self, input): if train: model.train() - output = model(input) + output = model(input).local_value() output.norm().backward() else: model.eval() @@ -106,7 +106,7 @@ def forward(self, input): @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -def test_no_portal_without_pipe(train, monkeypatch): +def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): def deny(*args, **kwargs): raise AssertionError("tried to create Portal without Pipe") diff --git a/test/distributed/_pipeline/sync/test_bugs.py b/test/distributed/_pipeline/sync/test_bugs.py index 4f5346a837b5..a66b7d006ae1 100644 --- a/test/distributed/_pipeline/sync/test_bugs.py +++ b/test/distributed/_pipeline/sync/test_bugs.py @@ -12,7 +12,7 @@ from torch.distributed._pipeline.sync import Pipe -def test_python_autograd_function(): +def test_python_autograd_function(setup_rpc): # A Python autograd function might fail with this error: # # RuntimeError: Returning Variables sharing storage with other Variables @@ -41,10 +41,10 @@ def forward(self, input): x = torch.rand(42) y = model(x) - assert torch.allclose(x, y) + assert torch.allclose(x, y.local_value()) -def test_exception_no_hang(): +def test_exception_no_hang(setup_rpc): # In v0.0.2, once a failed partition receives a normal message # (non-closing) for the next micro-batch, a hang occured. The reason was # that a failed partition didn't call in_queue.task_done() on a normal @@ -69,7 +69,7 @@ def forward(self, x): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") -def test_tuple_wait(cuda_sleep): +def test_tuple_wait(cuda_sleep, setup_rpc): # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. # Under this behavior, if checkpointing was disabled, there's a possibility # that gradient accumulations on other tensors are not synchronized @@ -113,7 +113,7 @@ def forward(self, triple): b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) y = model((a, b)) - y.norm().backward() + y.local_value().norm().backward() torch.cuda.synchronize(0) torch.cuda.synchronize(1) @@ -121,7 +121,7 @@ def forward(self, triple): assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) -def test_parallel_randoms(): +def test_parallel_randoms(setup_rpc): class Dropouts(nn.Module): def forward(self, x): for _ in range(100): @@ -133,6 +133,7 @@ def forward(self, x): x = torch.rand(10, 10, requires_grad=True) model = Pipe(model, chunks=10, checkpoint="always") y = model(x) + y = y.local_value() y.norm().backward() assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() diff --git a/test/distributed/_pipeline/sync/test_inplace.py b/test/distributed/_pipeline/sync/test_inplace.py index 17b3dac4eca8..3b842dbfb9ab 100644 --- a/test/distributed/_pipeline/sync/test_inplace.py +++ b/test/distributed/_pipeline/sync/test_inplace.py @@ -11,12 +11,12 @@ from torch.distributed._pipeline.sync import Pipe -def test_inplace_on_requires_grad(): +def test_inplace_on_requires_grad(setup_rpc): model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) model = Pipe(model, checkpoint="always") x = torch.rand(1) - y = model(x) + y = model(x).local_value() message = r"a leaf Variable that requires grad .* used in an in-place operation." with pytest.raises(RuntimeError, match=message): @@ -24,14 +24,14 @@ def test_inplace_on_requires_grad(): @pytest.mark.xfail(strict=True) -def test_inplace_on_not_requires_grad(): +def test_inplace_on_not_requires_grad(setup_rpc): # In-place operation on a tensor not requiring grad doesn't cause a # RuntimeError. Currently, we cannot detect this case. model = nn.Sequential(nn.ReLU(inplace=True)) model = Pipe(model, [1], devices=["cpu"], checkpoint="always") x = torch.rand(1) - y = model(x) + y = model(x).local_value() del model message = r"a leaf Variable that requires grad .* used in an in-place operation." @@ -40,7 +40,7 @@ def test_inplace_on_not_requires_grad(): @pytest.mark.xfail(strict=True) -def test_inplace_incorrect_grad(): +def test_inplace_incorrect_grad(setup_rpc): class M(nn.Module): def forward(self, foo_bar): # 'foo' requires grad but 'bar' does not. In-place operation on @@ -62,7 +62,7 @@ def forward(self, foo_bar): foo = torch.tensor([1.0], requires_grad=True) bar = torch.tensor([1.0]) - output = model((foo, bar)) + output = model((foo, bar)).local_value() del model output.backward() diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py index c0992c7bc0ed..ad00d9ffc0d8 100644 --- a/test/distributed/_pipeline/sync/test_pipe.py +++ b/test/distributed/_pipeline/sync/test_pipe.py @@ -68,7 +68,7 @@ def test_chunks_less_than_1(): with pytest.raises(ValueError): Pipe(model, chunks=-1) -def test_batch_size_indivisible(): +def test_batch_size_indivisible(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=4) @@ -79,7 +79,7 @@ def test_batch_size_indivisible(): assert not record -def test_batch_size_small(): +def test_batch_size_small(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=4) @@ -90,7 +90,7 @@ def test_batch_size_small(): assert not record -def test_checkpoint_mode(): +def test_checkpoint_mode(setup_rpc): def count_grad_fn(grad_fn, name, visited=None): if visited is None: visited = set() @@ -119,9 +119,9 @@ def count_grad_fn(grad_fn, name, visited=None): except_last_output = except_last(input) never_output = never(input) - assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2 - assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1 - assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0 + assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 + assert count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") == 1 + assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 def test_checkpoint_mode_invalid(): @@ -140,7 +140,7 @@ def test_checkpoint_mode_when_chunks_1(): Pipe(model, chunks=1, checkpoint="never") -def test_checkpoint_eval(): +def test_checkpoint_eval(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=2) input = torch.rand(2, 1) @@ -157,16 +157,16 @@ def find_grad_fn(grad_fn, name): model.train() train_output = model(input) - assert find_grad_fn(train_output.grad_fn, "CheckpointBackward") - assert find_grad_fn(train_output.grad_fn, "RecomputeBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") model.eval() eval_output = model(input) - assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward") - assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") -def test_checkpoint_non_float_input(): +def test_checkpoint_non_float_input(setup_rpc): class ForkNonFloat(nn.Module): def forward(self, input): return (input * 2, torch.tensor([False])) @@ -183,7 +183,7 @@ def forward(self, input): output.backward() -def test_no_grad(): +def test_no_grad(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=2) input = torch.rand(2, 1) @@ -206,7 +206,7 @@ def hook(module, input, output): assert latent.grad_fn is None -def test_exception(): +def test_exception(setup_rpc): class ExpectedException(Exception): pass @@ -221,7 +221,7 @@ def forward(self, *_): model(torch.rand(1)) -def test_exception_early_stop_asap(): +def test_exception_early_stop_asap(setup_rpc): """Even the first partitions have finished to process, the partition before the failed partition should be killed as soon as possible. """ @@ -258,7 +258,32 @@ def forward(self, x): assert counter == 2 -def test_input_pair(): +def test_nested_input(setup_rpc): + class NestedInput(nn.Module): + def __init__(self): + super().__init__() + self.fc_a = nn.Linear(1, 1) + self.fc_b = nn.Linear(1, 1) + + def forward(self, inp): + return inp + + model = nn.Sequential(NestedInput()) + model = Pipe(model, chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + b = torch.rand(10, 1, requires_grad=True) + + # TypeError: expected Tensor, but got tuple + with pytest.raises(TypeError): + model((a, (a, b))).local_value() + + # TypeError: expected Tensor, but got list + with pytest.raises(TypeError): + model((a, [a, b])).local_value() + + +def test_input_pair(setup_rpc): class Two(nn.Module): def __init__(self): super().__init__() @@ -275,15 +300,26 @@ def forward(self, a_and_b): a = torch.rand(10, 1, requires_grad=True) b = torch.rand(10, 1, requires_grad=True) - a_out, b_out = model((a, b)) + a_out, b_out = model((a, b)).local_value() loss = (a_out + b_out).mean() loss.backward() assert a.grad is not None assert b.grad is not None + # Test with list. + a.grad = None + b.grad = None + a_out, b_out = model([a, b]).local_value() + loss = (a_out + b_out).mean() + loss.backward() + + assert a.grad is not None + assert b.grad is not None -def test_input_singleton(): + + +def test_input_singleton(setup_rpc): class One(nn.Module): def __init__(self): super().__init__() @@ -298,7 +334,19 @@ def forward(self, only_a): a = torch.rand(10, 1, requires_grad=True) - (a_out,) = model((a,)) + (a_out,) = model((a,)).local_value() + loss = a_out.mean() + loss.backward() + + assert all(p.grad is not None for p in model.parameters()) + assert a.grad is not None + + # Test with list + a.grad = None + for p in model.parameters(): + p.grad = None + + (a_out,) = model([a]).local_value() loss = a_out.mean() loss.backward() @@ -306,7 +354,7 @@ def forward(self, only_a): assert a.grad is not None -def test_input_varargs(): +def test_input_varargs(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model) @@ -318,7 +366,7 @@ def test_input_varargs(): model(a, b) -def test_non_tensor(): +def test_non_tensor(setup_rpc): class NonTensor(nn.Module): def forward(self, _): return "hello" @@ -336,7 +384,7 @@ def forward(self, _): model("hello") -def test_non_tensor_tuple(): +def test_non_tensor_sequence(setup_rpc): class NonTensorTuple(nn.Module): def forward(self, x): return (x, "hello") @@ -353,9 +401,13 @@ def forward(self, x): with pytest.raises(TypeError): model((x, "hello")) + # TypeError: expected Tensor to scatter, but got str + with pytest.raises(TypeError): + model([x, "hello"]) + @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_deferred_batch_norm(checkpoint): +def test_deferred_batch_norm(checkpoint, setup_rpc): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( @@ -363,7 +415,7 @@ def test_deferred_batch_norm(checkpoint): ) x = torch.rand(4, 3, 10, 10) - pipe(x).mean().backward() + pipe(x).local_value().mean().backward() bn(x).mean().backward() assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) @@ -371,7 +423,7 @@ def test_deferred_batch_norm(checkpoint): @pytest.mark.parametrize("checkpoint", ["never", "always"]) -def test_deferred_batch_norm_params(checkpoint): +def test_deferred_batch_norm_params(checkpoint, setup_rpc): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( @@ -379,7 +431,7 @@ def test_deferred_batch_norm_params(checkpoint): ) x = torch.rand(4, 3, 10, 10) - pipe(x).mean().backward() + pipe(x).local_value().mean().backward() bn(x).mean().backward() assert pipe[0].weight.grad is not None @@ -455,13 +507,13 @@ def test_deny_moving(): model.to(dtype=torch.float) -def test_empty_module(): +def test_empty_module(setup_rpc): # Empty sequential module is not illegal. model = nn.Sequential() model = Pipe(model) - assert model(torch.tensor(42)) == torch.tensor(42) - assert model((torch.tensor(42),)) == (torch.tensor(42),) + assert model(torch.tensor(42)).local_value() == torch.tensor(42) + assert model((torch.tensor(42),)).local_value() == (torch.tensor(42),) # But only tensor or tensors is legal in Pipe. with pytest.raises(TypeError): @@ -518,7 +570,7 @@ def __init__(self, param1, param2): @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs") -def test_verify_nested_modules(): +def test_verify_nested_modules(setup_rpc): model = nn.Sequential( nn.Sequential( nn.Linear(32, 16).cuda(0), @@ -532,8 +584,8 @@ def test_verify_nested_modules(): pipe = Pipe(model) out = pipe(torch.rand(10, 32).cuda(0)) - assert out.device == torch.device("cuda:1") - assert out.size() == torch.Size([10, 2]) + assert out.local_value().device == torch.device("cuda:1") + assert out.local_value().size() == torch.Size([10, 2]) def test_verify_module_duplicate_parameters_on_same_device(): class Surrogate(nn.Module): @@ -547,7 +599,7 @@ def __init__(self, module): Pipe(model) -def test_forward_lockstep(): +def test_forward_lockstep(setup_rpc): timeline = [] class DelayedLog(nn.Module): diff --git a/test/distributed/_pipeline/sync/test_transparency.py b/test/distributed/_pipeline/sync/test_transparency.py index 3d2c77e8fef4..56ad86de081b 100644 --- a/test/distributed/_pipeline/sync/test_transparency.py +++ b/test/distributed/_pipeline/sync/test_transparency.py @@ -10,7 +10,7 @@ from torch.distributed._pipeline.sync import Pipe -def test_simple_linears(): +def test_simple_linears(setup_rpc): def sum_grad(parameters): return sum([p.grad.sum() for p in parameters if p.grad is not None]) @@ -33,7 +33,7 @@ def zero_grad(parameters): # With Pipe model = Pipe(model, chunks=4) - outputs = model(inputs) + outputs = model(inputs).local_value() loss = outputs.mean() loss.backward() diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py index 3b25be6e49c1..4b3e962d835f 100644 --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -3268,6 +3268,20 @@ def forward(self, x): loss = criterion(output, target) loss.backward() + @requires_nccl() + @skip_if_not_multigpu + def test_pass_default_pg(self): + dist.init_process_group( + "nccl", + init_method=f"file://{self.file_name}", + world_size=self.world_size, + rank=self.rank, + ) + + default_pg = c10d.distributed_c10d._get_default_group() + dist.destroy_process_group(default_pg) + self.assertFalse(dist.is_initialized()) + @requires_nccl() @skip_if_not_multigpu def test_save_load_checkpoint(self): diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py index 1bfa3922bd94..f3161a1f8cb1 100644 --- a/test/distributed/test_data_parallel.py +++ b/test/distributed/test_data_parallel.py @@ -18,6 +18,7 @@ torch.set_default_dtype(torch.double) +NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL") class TestDataParallel(TestCase): @@ -597,6 +598,25 @@ def test_scatter_cpu(self): def test_scatter_gpu(self): self._test_scatter(torch.randn((4, 4)).cuda()) + @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @unittest.skipIf(NO_NCCL, "NCCL needed") + def test_data_parallel_complex(self): + # We expect complex parameters to be broadcast by view_as_real, e.g. move from C to R^2 + class Cplx(torch.nn.Module): + def __init__(self): + super().__init__() + self.cplx = torch.nn.Parameter(torch.zeros(1, 10, dtype=torch.cfloat).cuda()) + + def forward(self, x): + return x + self.cplx + + cplx = torch.nn.DataParallel(Cplx().cuda()) + input = torch.rand(1, 10, dtype=torch.cfloat).cuda() + result = cplx(input) + # 2 is the extra real view dimension here + self.assertEqual(result.size(), torch.Size([1, 10, 2])) + self.assertEqual(result, torch.view_as_real(input)) + def _test_gather(self, output_device): inputs = ( torch.randn(2, 4, device='cuda:0', requires_grad=True), diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index abba69eb472f..b057d12a285d 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -39,12 +39,13 @@ from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests from torch.testing._internal.common_cuda import TEST_CUDA from torch.autograd import grad, gradcheck +from torch.autograd.functional import jacobian from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2, ContinuousBernoulli, Dirichlet, Distribution, Exponential, ExponentialFamily, FisherSnedecor, Gamma, Geometric, Gumbel, - HalfCauchy, HalfNormal, - Independent, Kumaraswamy, Laplace, LogisticNormal, + HalfCauchy, HalfNormal, Independent, Kumaraswamy, + LKJCholesky, Laplace, LogisticNormal, LogNormal, LowRankMultivariateNormal, MixtureSameFamily, Multinomial, MultivariateNormal, NegativeBinomial, Normal, @@ -58,7 +59,8 @@ from torch.distributions.kl import _kl_expfamily_expfamily from torch.distributions.transforms import (AffineTransform, CatTransform, ExpTransform, StackTransform, identity_transform) -from torch.distributions.utils import probs_to_logits, lazy_property +from torch.distributions.utils import (probs_to_logits, lazy_property, tril_matrix_to_vec, + vec_to_tril_matrix) from torch.nn.functional import softmax # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for @@ -246,6 +248,20 @@ def is_all_nan(tensor): 'concentration0': torch.rand(4).uniform_(1, 2).requires_grad_(), }, ]), + Example(LKJCholesky, [ + { + 'dim': 2, + 'concentration': 0.5 + }, + { + 'dim': 3, + 'concentration': torch.tensor([0.5, 1., 2.]), + }, + { + 'dim': 100, + 'concentration': 4. + }, + ]), Example(Laplace, [ { 'loc': torch.randn(5, 5, requires_grad=True), @@ -2265,10 +2281,10 @@ def test_gumbel_sample(self): 'Gumbel(loc={}, scale={})'.format(loc, scale)) def test_kumaraswamy_shape(self): - concentration1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True) - concentration0 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True) - concentration1_1d = torch.tensor(torch.randn(1).abs(), requires_grad=True) - concentration0_1d = torch.tensor(torch.randn(1).abs(), requires_grad=True) + concentration1 = torch.randn(2, 3).abs().requires_grad_() + concentration0 = torch.randn(2, 3).abs().requires_grad_() + concentration1_1d = torch.randn(1).abs().requires_grad_() + concentration0_1d = torch.randn(1).abs().requires_grad_() self.assertEqual(Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3)) self.assertEqual(Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3)) self.assertEqual(Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,)) @@ -2279,10 +2295,10 @@ def test_kumaraswamy_shape(self): # Kumaraswamy distribution is not implemented in SciPy # Hence these tests are explicit def test_kumaraswamy_mean_variance(self): - c1_1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True) - c0_1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True) - c1_2 = torch.tensor(torch.randn(4).abs(), requires_grad=True) - c0_2 = torch.tensor(torch.randn(4).abs(), requires_grad=True) + c1_1 = torch.randn(2, 3).abs().requires_grad_() + c0_1 = torch.randn(2, 3).abs().requires_grad_() + c1_2 = torch.randn(4).abs().requires_grad_() + c0_2 = torch.randn(4).abs().requires_grad_() cases = [(c1_1, c0_1), (c1_2, c0_2)] for i, (a, b) in enumerate(cases): m = Kumaraswamy(a, b) @@ -2534,6 +2550,29 @@ def test_continuous_bernoulli_3d(self): (2, 5, 2, 3, 5)) self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5)) + def test_lkj_cholesky_log_prob(self): + def tril_cholesky_to_tril_corr(x): + x = vec_to_tril_matrix(x, -1) + diag = (1 - (x * x).sum(-1)).sqrt().diag_embed() + x = x + diag + return tril_matrix_to_vec(x @ x.T, -1) + + for dim in range(2, 5): + log_probs = [] + lkj = LKJCholesky(dim, concentration=1.) + for i in range(2): + sample = lkj.sample() + sample_tril = tril_matrix_to_vec(sample, diag=-1) + log_prob = lkj.log_prob(sample) + log_abs_det_jacobian = torch.slogdet(jacobian(tril_cholesky_to_tril_corr, sample_tril)).logabsdet + log_probs.append(log_prob - log_abs_det_jacobian) + # for concentration=1., the density is uniform over the space of all + # correlation matrices. + if dim == 2: + # for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.) + self.assertTrue(all([x == torch.tensor(0.5).log() for x in log_probs])) + self.assertEqual(log_probs[0], log_probs[1]) + def test_independent_shape(self): for Dist, params in EXAMPLES: for param in params: @@ -4362,6 +4401,22 @@ def test_cat_transform_non_uniform(self): t2.log_abs_det_jacobian(x2, y2)], dim=dim) self.assertEqual(actual_jac, expected_jac) + def test_cat_event_dim(self): + t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1) + t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1) + dim = 1 + bs = 16 + x1 = torch.randn(bs, 2) + x2 = torch.randn(bs, 2) + x = torch.cat([x1, x2], dim=1) + t = CatTransform([t1, t2], dim=dim, lengths=[2, 2]) + y1 = t1(x1) + y2 = t2(x2) + y = t(x) + actual_jac = t.log_abs_det_jacobian(x, y) + expected_jac = sum([t1.log_abs_det_jacobian(x1, y1), + t2.log_abs_det_jacobian(x2, y2)]) + def test_stack_transform(self): x1 = -1 * torch.arange(1, 101, dtype=torch.float) x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100 diff --git a/test/distributions/test_utils.py b/test/distributions/test_utils.py index b58cfe39fc1c..5751246eb10a 100644 --- a/test/distributions/test_utils.py +++ b/test/distributions/test_utils.py @@ -13,7 +13,7 @@ def test_tril_matrix_to_vec(shape): mat = torch.randn(shape) n = mat.shape[-1] - for diag in range(-n + 1, n): + for diag in range(-n, n): actual = mat.tril(diag) vec = tril_matrix_to_vec(actual, diag) tril_mat = vec_to_tril_matrix(vec, diag) diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py new file mode 100644 index 000000000000..db06663285da --- /dev/null +++ b/test/fx/test_fx_const_fold.py @@ -0,0 +1,274 @@ +import unittest + +import torch +from torch.fx.experimental import const_fold + + +class TestConstFold(unittest.TestCase): + def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule): + self.assertTrue(mod_folded.const_subgraph_module is not None) + + # Check that the constants are attributes in the main subgraph. + num_folded_attrs = 0 + for node in mod_folded.graph.nodes: + if node.op == "get_attr" and (node.target in mod_folded.const_output_names): + num_folded_attrs += 1 + self.assertEqual(num_folded_attrs, len(mod_folded.const_output_names)) + + def test_const_fold_basic_one_attr_no_name_collision(self): + r""" + Perform constant folding conversion, from original mod to split constant folding + module with two split subgraphs, where there's a single attr to fold and + a single output attr result to replace. + + attr1 attr1 + | | | | + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul attr2 x / (input from previous subgraph + \ / \ / is attr) + add sub y + | \ / + output mul attr2 + \ / + add + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]])) + + def forward(self, x, y): + a = self.attr_1 + self.attr_1 + x = x - a + return x * y + self.attr_2 + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) + base_result = mod(in_x, in_y) + fold_result = mod_folded(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_one_attr_name_collision(self): + r""" + Perform constant folding conversion, from original mod to split constant folding + module with two split subgraphs, where there's a single attr to fold and + a single output attr result to replace. Name the attrs such that they will + collide by name with folded attrs. + + add_1 add_1 + | | | | + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul add_2 x / (input from previous subgraph + \ / \ / is attr) + add sub y + | \ / + output mul add_2 + \ / + add + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Note: Named as such to result in name collision. + self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]])) + self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]])) + + def forward(self, x, y): + a = self.add_1__CF + self.add_1__CF + x = x - a + return x * y + self.add_2__CF + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0]) + base_result = mod(in_x, in_y) + fold_result = mod_folded(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_noop(self): + r""" + Check that a graph with no constant folding is handled correctly. + + x attr1 + \ / + sub + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + + def forward(self, x): + return x - self.attr1 + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + + # Check that the folded graph module is None, since there was no folding to do. + self.assertTrue(mod_folded.const_subgraph_module is None) + + # Now run both folded and non-folded to check results equal. + in_x = torch.tensor([[-0.45]]) + base_result = mod(in_x) + fold_result = mod_folded(in_x) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_two_attr_three_input(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into a single output, and there are three placeholder inputs. + + attr1 attr2 attr1 attr2 + \ / \ / + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul z x / (input from previous subgraph + \ / \ / is attr) + div sub y + | \ / + output mul z + \ / + div + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]])) + + def forward(self, x, y, z): + a = self.attr1 + self.attr1 + sub = x - a + mul = sub * y + return mul / z + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y, in_z = ( + torch.tensor([[-0.45]]), + torch.tensor([0.9]), + torch.tensor([1.1]), + ) + base_result = mod(in_x, in_y, in_z) + fold_result = mod_folded(in_x, in_y, in_z) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_two_attr(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into a single output. + + attr1 attr2 attr1 attr2 + \ / \ / + x add add (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + sub x | (input from previous subgraph is attr) + | \ / + output sub + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.randn(2, 3)) + self.attr2 = torch.nn.Parameter(torch.randn(2, 3)) + + def forward(self, x): + y = self.attr1 + self.attr2 + return x + y + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x = torch.randn(2, 3) + fold_result = mod_folded(in_x) + base_result = mod(in_x) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_multi_const_folded_attrs(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into two new attrs. + + attr1 attr2 attr1 attr2 + / \ | / \ | + permute | sum permute | sum + \ / / \ / | + x add y / add | + \ / \ / | | + sub add output output (become attrs add_1 and mul_1) + \ / ==> --------+-------+------ (const/base subgraph split) + \ / x | y | (inputs from previous subgraph + add \ / \ / are attrs) + | sub add + linear \ / + | add + sigmoid | + | linear + output | + sigmoid + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.randn(4, 4)) + self.attr2 = torch.nn.Parameter(torch.randn(4, 4)) + self.lin = torch.nn.Linear(4, 4) + + def forward(self, x, y): + a = self.attr1 + self.attr1.permute(1, 0) + x = x - a + amax = torch.sum(self.attr2, dim=1) + y = y + amax + return torch.sigmoid(self.lin(x + y)) + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.randn(4, 4), torch.randn(4) + fold_result = mod_folded(in_x, in_y) + base_result = mod(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index dafc95013b96..04991f72c352 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -109,22 +109,31 @@ def fn(x): return a def test_del_multiple_operands(self): + def fn(x): + # type: (List[int]) -> List[int] + a, b, c = x[0], x[1], x[2] + del a, b, c + return x - with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, - "with more than one operand"): - @torch.jit.script - def del_list_multiple_operands(x): - # type: (List[int]) -> List[int] - del x[0], x[1] - return x + self.checkScript(fn, ([1, 2, 3],)) - with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, - "with more than one operand"): - @torch.jit.script - def del_dict_multiple_operands(x): - # type: (Dict[str, int]) -> Dict[str, int] - del x['hi'], x['there'] - return x + def del_list_multiple_operands(x): + # type: (List[int]) -> List[int] + del x[0], x[1] + return x + + py_out = del_list_multiple_operands([0, 1, 2]) + jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2]) + self.assertEquals(py_out, jit_out) + + def del_dict_multiple_operands(x): + # type: (Dict[str, int]) -> Dict[str, int] + del x['hi'], x['there'] + return x + + py_out = del_dict_multiple_operands({"hi": 5, "there": 6}) + jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6}) + self.assertEquals(py_out, jit_out) class TestTensorBuiltins(JitTestCase): diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index b4075dba14c8..a80670f0d22b 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -959,6 +959,26 @@ def forward(self, x): # Make sure class constant is accessible from module self.assertEqual(m.w, m_loaded.w) + def test_py_class_to_ivalue_missing_attribute(self): + global Foo # see [local resolution in python] + + class Foo(object): + i : int + f : float + + def __init__(self, i : int, f : float): + self.i = i + self.f = f + + @torch.jit.script + def test_fn(x : Foo) -> float: + return x.i + x.f + + test_fn(Foo(3, 4.0)) + + with self.assertRaisesRegex(RuntimeError, 'missing attribute i'): + test_fn(torch.rand(3, 4)) + def test_unused_method(self): """ Test unused methods on scripted classes. diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py index e42f8225a3d6..e763a730c473 100644 --- a/test/jit/test_profiler.py +++ b/test/jit/test_profiler.py @@ -25,7 +25,8 @@ def setUp(self): self.default_dtype = torch.get_default_dtype() self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True) torch.set_default_dtype(torch.double) - + self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(False) def tearDown(self): torch._C._jit_set_profiling_executor(self.prev_exec) @@ -35,6 +36,7 @@ def tearDown(self): torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu) torch.set_default_dtype(self.default_dtype) torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled) + torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) def test_tensor_type_not_determined_by_inputs(self): @torch.jit.script @@ -212,6 +214,19 @@ def foo(a, b): g = torch.jit.last_executed_optimized_graph() FileCheck().check("fallback_function").check_next("CallFunction").run(g) + def test_tensor_constant(self): + def foo(a, b): + return a + b + torch.tensor([2]) + + x = torch.ones(1, requires_grad=False) + foo_script = torch.jit.script(foo) + foo_script(x, x) + foo_script(x, x) + + self.assertEqual(foo_script(x, x), foo(x, x)) + g = torch.jit.last_executed_optimized_graph() + FileCheck().check_count("aten::add", 2, exactly=True).run(g) + def test_iterative_fusion(self): @torch.jit.script def foo(a, b, c, d): diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index af7897e159b3..31eec81d480a 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -240,6 +240,10 @@ def forward(self): traced = torch.jit.trace(TryTracing(), ()) self.assertEqual(torch.zeros(4, 4), traced()) + def test_torchbind_pass_wrong_type(self): + with self.assertRaisesRegex(RuntimeError, 'missing attribute capsule'): + torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4)) + def test_torchbind_tracing_nested(self): class TryTracingNest(torch.nn.Module): def __init__(self): diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 8ccf0fdfdb89..f6fa533d7837 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -776,7 +776,7 @@ def forward(self, x_in): return x_out x = {torch.tensor(1.): torch.randn(1, 2, 3)} - self.assertONNX(MyModel(), (x,)) + self.assertONNX(MyModel(), (x, {})) def test_dict_str(self): class MyModel(torch.nn.Module): @@ -786,7 +786,7 @@ def forward(self, x_in): return x_out x = {"test_key_in": torch.randn(1, 2, 3)} - self.assertONNX(MyModel(), (x,)) + self.assertONNX(MyModel(), (x, {})) def test_arange_dynamic(self): class TestModel(torch.nn.Module): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index e2e12af88c1e..c481d58e4bb5 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -81,14 +81,19 @@ def run_model_test(self, model, batch_size=2, state_dict=None, if input is None: input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) - with torch.no_grad(): if isinstance(input, torch.Tensor): input = (input,) # In-place operators will update input tensor data as well. # Thus inputs are replicated before every forward call. - input_copy = copy.deepcopy(input) - output = model(*input_copy) + if isinstance(input, dict): + input = (input,) + input_args = copy.deepcopy(input) + input_kwargs = {} + if isinstance(input_args[-1], dict): + input_kwargs = input_args[-1] + input_args = input_args[:-1] + output = model(*input_args, **input_kwargs) if isinstance(output, torch.Tensor): output = (output,) @@ -459,7 +464,7 @@ def forward(self, x_in): return x_out x = {torch.tensor(1.): torch.randn(1, 2, 3)} - self.run_test(MyModel(), (x,)) + self.run_test(MyModel(), (x, {})) @disableScriptTest() def test_dict_str(self): @@ -470,7 +475,101 @@ def forward(self, x_in): return x_out x = {"test_key_in": torch.randn(1, 2, 3)} - self.run_test(MyModel(), (x,)) + self.run_test(MyModel(), (x, {})) + + def test_optional_inputs_with_no_optionals(self): + class NoOptionalModel(torch.nn.Module): + def forward(self, input): + return input + + # Without empty optional arguments dictionary + x = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (x,)) + # With empty optional arguments dictionary + y = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (y, {})) + + def test_optional_inputs_with_mixed_optionals(self): + class MixedModel(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(MixedModel(), (x, y, None)) + self.run_test(MixedModel(), (x, None, z)) + # With optional arguments dictionary + self.run_test(MixedModel(), (x, {'y': y, 'z': None})) + self.run_test(MixedModel(), (x, {'y': None, 'z': z})) + self.run_test(MixedModel(), (x, {'z': z})) + self.run_test(MixedModel(), (x, {'y': y})) + + def test_optional_inputs_with_all_optionals(self): + class AllOptionalModel(torch.nn.Module): + def forward(self, y=None, z=None): + if y is not None: + return y + if z is not None: + return z + + y = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(AllOptionalModel(), (y, None)) + # With optional arguments dictionary + self.run_test(AllOptionalModel(), {'y': y, 'z': None}) + + def test_input_names_with_optional_args(self): + class NoOptionalModel(torch.nn.Module): + def forward(self, input): + return input + + # Without empty optional arguments dictionary + x = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (x,), input_names=['input_x']) + # With empty optional arguments dictionary + y = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (y, {})) + + class MixedModel(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(MixedModel(), (x, y, None), input_names=['input_x', 'input_y']) + self.run_test(MixedModel(), (x, None, z), input_names=['input_x', 'input_z']) + + # With optional arguments dictionary + self.run_test(MixedModel(), (x, {'y': y, 'z': None}), input_names=['input_x', 'input_y']) + self.run_test(MixedModel(), (x, {'y': None, 'z': z}), input_names=['input_x', 'input_z']) + + class AllOptionalModel(torch.nn.Module): + def forward(self, y=None, z=None): + if y is not None: + return y + if z is not None: + return z + + y = torch.randn(2, 3) + z = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(AllOptionalModel(), (y, None), input_names=['input_y']) + self.run_test(AllOptionalModel(), (None, z), input_names=['input_z']) + # With optional arguments dictionary + self.run_test(AllOptionalModel(), {'y': y, 'z': None}, input_names=['input_y']) + self.run_test(AllOptionalModel(), {'y': None, 'z': z}, input_names=['input_z']) @disableScriptTest() def test_none_as_input(self): @@ -754,7 +853,10 @@ def forward(self, x): return x.transpose(0, 1) x = torch.randn(32, 3, 64, 64) - self.run_test(TransposeModule(), x) + y = torch.randn(16, 3, 8, 64) + self.run_test(TransposeModule(), x, input_names=['x'], + dynamic_axes={'x': [0, 2]}, + test_with_inputs=[y]) def squeeze_model_tests(self, d, x1, x2): class Squeeze(torch.nn.Module): @@ -841,7 +943,10 @@ def forward(self, x): def test_maxpool_adaptive(self): model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False) x = torch.randn(20, 16, 50, requires_grad=True) - self.run_test(model, x) + y = torch.randn(32, 16, 50, requires_grad=True) + self.run_test(model, x, input_names=['x'], + dynamic_axes={'x' : [0]}, + test_with_inputs=[y]) def test_maxpool_2d(self): model = torch.nn.MaxPool2d(5, padding=(1, 2)) @@ -903,7 +1008,10 @@ def test_avgpool_2d_ceil(self): def test_avgpool_3d_ceil(self): model = torch.nn.AvgPool3d(3, 2, ceil_mode=True) x = torch.randn(20, 16, 50, 44, 31) - self.run_test(model, x) + y = torch.randn(32, 8, 50, 44, 31) + self.run_test(model, x, input_names=['x'], + dynamic_axes={'x' : [0, 1]}, + test_with_inputs=[y]) @skipIfUnsupportedMinOpsetVersion(9) def test_floating_point(self): @@ -3809,7 +3917,11 @@ def forward(self, x): return x.unfold(dimension=2, size=2, step=2) x = torch.randn(4, 2, 3, requires_grad=True) - self.run_test(UnfoldModel(), x) + y = torch.randn(2, 1, 3, requires_grad=True) + self.run_test(UnfoldModel(), x, + dynamic_axes={'x': [0, 1]}, + input_names=['x'], + test_with_inputs=[y]) @skipIfONNXShapeInference(False) def test_unfold_infer_shape(self): @@ -3826,6 +3938,21 @@ def forward(self, x): x = torch.randn(32, 3, 64) self.run_test(UnfoldModule(), x) + def test_prelu(self): + class PReluModel(torch.nn.Module): + def __init__(self): + super(PReluModel, self).__init__() + self.prelu = torch.nn.PReLU() + + def forward(self, x): + return self.prelu(x) + + x = torch.randn(2, 3, 4) + y = torch.randn(2, 4, 5) + self.run_test(PReluModel(), x, input_names=['x'], + dynamic_axes={'x': [1, 2]}, + test_with_inputs=[y]) + def test_remainder(self): class RemainderModel(torch.nn.Module): def forward(self, input, other): @@ -3862,6 +3989,15 @@ def forward(self, input): x = torch.randint(10, (2, 3)) self.run_test(FModModel(), x) + @skipIfUnsupportedMinOpsetVersion(9) + def test_glu(self): + class GluModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.glu(x) + + x = torch.randn(2, 4, 5, 6, requires_grad=True) + self.run_test(GluModel(), x) + @skipIfUnsupportedMinOpsetVersion(9) def test_gelu(self): class GeluModel(torch.nn.Module): diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py index 745437a86ca3..9e6379a29cec 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize.py @@ -352,10 +352,9 @@ def test_resnet_base(self): with override_quantized_engine(qengine): qconfig = torch.quantization.get_default_qconfig(qengine) model = ResNetBase().float().eval() + model.fuse_model() model = QuantWrapper(model) model.qconfig = qconfig - fuse_list = ['module.conv1', 'module.bn1', 'module.relu1'] - fuse_modules(model, fuse_list, inplace=True) model = prepare(model) self.checkObservers(model) test_only_eval_fn(model, self.img_data_2d) @@ -365,6 +364,8 @@ def checkQuantized(model): self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d) self.assertEqual(type(model.module.myop), nn.quantized.QFunctional) self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d) + self.assertEqual(type(model.module.fc), nnq.Linear) + test_only_eval_fn(model, self.img_data_2d) self.checkNoQconfig(model) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 7e4048b98cbf..7c6c548f2594 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -14,14 +14,21 @@ prepare_qat_fx, ) +from torch.quantization.fx.pattern_utils import ( + is_match, + MatchAllNode, +) + from torch.quantization import ( QuantType, QuantStub, DeQuantStub, + QuantWrapper, quant_type_to_str, default_qconfig, default_dynamic_qconfig, default_qat_qconfig, + per_channel_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig, get_default_qconfig, @@ -30,6 +37,7 @@ prepare, prepare_qat, convert, + quantize_dynamic, default_placeholder_observer, PerChannelMinMaxObserver, QConfigDynamic, @@ -44,7 +52,15 @@ skip_if_no_torchvision, train_one_epoch, run_ddp, + test_only_eval_fn, + test_only_train_fn, +) + +from torch.testing._internal.common_quantization import ( LinearModelWithSubmodule, + ResNetBase, + RNNDynamicModel, + RNNCellDynamicModel, ) from torch.testing._internal.common_quantized import ( @@ -186,6 +202,31 @@ def forward(self, x): @skipIfNoFBGEMM class TestQuantizeFx(QuantizationTestCase): + def test_pattern_match(self): + """ test MatchAllNode with + conv - bn - add - relu pattern + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + self.bn = nn.BatchNorm2d(1) + self.relu = nn.ReLU() + + def forward(self, x, y): + x = self.conv(x) + x = self.bn(x) + x = x + y + x = self.relu(x) + return x + + pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) + m = torch.fx.symbolic_trace(M()) + modules = dict(m.named_modules()) + for n in m.graph.nodes: + if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: + self.assertTrue(is_match(modules, n, pattern)) + def _get_conv_linear_test_cases(self): ''' Returns a list of test cases, with format: is_dynamic, ModuleClass, module_constructor_inputs, @@ -2070,6 +2111,49 @@ def forward(self, indices, offsets): # make sure it runs m(*inputs) + def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): + options = itertools.product(qconfigs, module_type_strs) + for qconfig, module_type_str in options: + model_eager = M(module_type_str).eval() + model_graph = copy.deepcopy(model_eager) + if torch.backends.quantized.engine == 'qnnpack' and \ + qconfig is float16_dynamic_qconfig: + continue + # fp16 dynamic quant is not supported for qnnpack + + eager_qconfig_dict = {x : qconfig for x in module_types} + model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict) + + graph_qconfig_dict = { + "object_type": [ + (x, qconfig) for x in module_types + ] + } + model_graph = prepare_fx(model_graph, graph_qconfig_dict) + model_graph = convert_fx(model_graph) + self.assertEqual(model_eager(sample_input), model_graph(sample_input)) + self.checkScriptable(model_graph, [[sample_input]], True) + + def test_rnn_cell(self): + qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] + module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'] + module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell] + sample_input = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float) + self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input) + + def test_rnn(self): + qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] + module_type_strs = ['LSTM'] + module_types = [torch.nn.LSTM] + niter = 10 + sample_input = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) + self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input) + + class TestQuantizeFxModels(QuantizationTestCase): def _test_model_impl( self, mode, name, model, eager_quantizable_model, @@ -2189,6 +2273,58 @@ def _test_model_impl( ' should match. Mode: ' + mode + ' diff:' + str(diff_from_eager[mode][name])) + def _test_building_block(self, quant_type, BB): + eager = BB().float() + graph = copy.deepcopy(eager) + + if quant_type == QuantType.STATIC: + qconfig = default_qconfig + eager_prepare = prepare + graph_prepare = prepare_fx + eager.eval() + graph.eval() + calibrate_or_train = test_only_eval_fn + data = self.img_data_2d + else: + assert quant_type == QuantType.QAT + qconfig = default_qat_qconfig + eager_prepare = prepare_qat + graph_prepare = prepare_qat_fx + eager.train() + graph.train() + calibrate_or_train = test_only_train_fn + data = self.img_data_2d_train + + if hasattr(eager, "fuse_model"): + eager.fuse_model() + eager = QuantWrapper(eager) + eager.qconfig = qconfig + eager = eager_prepare(eager) + + qconfig_dict = {"": qconfig} + graph = graph_prepare(graph, qconfig_dict) + + eager_out = eager(data[0][0]) + graph_out = graph(data[0][0]) + self.assertEqual(eager_out, graph_out) + + calibrate_or_train(eager, data) + calibrate_or_train(graph, data) + + eager = convert(eager) + graph = convert_fx(graph) + + eager_out = eager(data[0][0]) + graph_out = graph(data[0][0]) + self.assertEqual(eager_out, graph_out) + + @override_qengines + def test_resnet_base(self): + models = [ResNetBase] + options = itertools.product(self.static_quant_types, models) + for quant_type, M in options: + self._test_building_block(quant_type, M) + @skip_if_no_torchvision @skipIfNoFBGEMM @unittest.skip("skip for now since tbb failed") diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py index f67a585d99b6..d8022befeed0 100644 --- a/test/quantization/test_quantize_jit.py +++ b/test/quantization/test_quantize_jit.py @@ -81,6 +81,7 @@ class TestQuantizeJitPasses(QuantizationTestCase): """ Test graph mode quantization passes used by quantize_jit """ + def test_foldbn_trivial(self): bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} @@ -2708,6 +2709,28 @@ def test_conv_with_benchmark_flag(self): FileCheck().check("quantized::conv2d") \ .run(converted_model.graph) + @skipIfNoFBGEMM + def test_cat_linear(self): + class LinearModel(torch.nn.Module): + def __init__(self): + super(LinearModel, self).__init__() + self.weight = torch.randn(5, 5) + + def forward(self, x, y): + a = torch.cat([x, y]) + b = F.linear(a, self.weight) + c = F.linear(b, self.weight) + return b, c + + model = LinearModel().eval() + qconfig = {'' : default_qconfig} + float_model = torch.jit.script(model) + prepared_model = prepare_jit(float_model, qconfig) + prepared_model(torch.rand(5, 5), torch.rand(5, 5)) + converted_model = convert_jit(prepared_model) + FileCheck().check("quantized::linear") \ + .check("quantized::linear") \ + .run(converted_model.graph) class TestQuantizeDynamicJitPasses(QuantizationTestCase): def test_prepare_dynamic(self): diff --git a/test/run_test.py b/test/run_test.py index 3687459a4a70..54cc33ebc484 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -162,6 +162,28 @@ 'distributed/rpc/test_process_group_agent', 'distributed/rpc/test_tensorpipe_agent', 'distributed/test_distributed_fork', + 'distributed/_pipeline/sync/skip/test_api', + 'distributed/_pipeline/sync/skip/test_gpipe', + 'distributed/_pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/_pipeline/sync/skip/test_leak', + 'distributed/_pipeline/sync/skip/test_portal', + 'distributed/_pipeline/sync/skip/test_stash_pop', + 'distributed/_pipeline/sync/skip/test_tracker', + 'distributed/_pipeline/sync/skip/test_verify_skippables', + 'distributed/_pipeline/sync/test_balance', + 'distributed/_pipeline/sync/test_bugs', + 'distributed/_pipeline/sync/test_checkpoint', + 'distributed/_pipeline/sync/test_copy', + 'distributed/_pipeline/sync/test_deferred_batch_norm', + 'distributed/_pipeline/sync/test_dependency', + 'distributed/_pipeline/sync/test_inplace', + 'distributed/_pipeline/sync/test_microbatch', + 'distributed/_pipeline/sync/test_phony', + 'distributed/_pipeline/sync/test_pipe', + 'distributed/_pipeline/sync/test_pipeline', + 'distributed/_pipeline/sync/test_stream', + 'distributed/_pipeline/sync/test_transparency', + 'distributed/_pipeline/sync/test_worker', ] ROCM_BLOCKLIST = [ diff --git a/test/test_autograd.py b/test/test_autograd.py index 7c2082b1ed1d..0d99169f4d65 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -29,7 +29,7 @@ record_function, emit_nvtx) import torch.autograd.functional as autogradF from torch.utils.checkpoint import checkpoint -from torch.testing._internal.common_utils import (TEST_WITH_ROCM, TestCase, run_tests, skipIfNoLapack, +from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack, suppress_warnings, slowTest, load_tests, random_symmetric_matrix, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck) @@ -4927,7 +4927,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub', 'exp', 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul', - 'addcdiv', 'linalg.tensorinv', 'matrix_exp'] + separate_complex_tests + 'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', ] + separate_complex_tests def add_test( name, @@ -6181,10 +6181,6 @@ def test_min_max_median_backprops_to_all_values(self, device): self.assertEqual(x.grad.sum(), 1.) self.assertEqual((x.grad == 1 / 3).sum(), 3) - # skip this test if running on rocm, because in cdist - # we use __shfl_down_sync on CUDA for fast reduction - # and it gives incorrect results on rocm platform - @skipCUDAIfRocm def test_cdist(self, device): def _test_cdist_for_size(sizex, sizey=None): if sizey is None: @@ -6268,8 +6264,6 @@ def test_parameter_resize(self, device): m = torch.cat((asd, asd)) m.sum().backward() - # NOTE: flaky on ROCm CI - @skipCUDAIfRocm def test_sparse_ctor_getter_backward(self, device): # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test def _test(size, sparse_dim, nnz, device): @@ -6590,7 +6584,6 @@ def test_ctc_loss_cudnn(self, device): grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out) self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0) - @skipCUDAIfRocm def test_leaky_relu_inplace_with_neg_slope(self, device): a = torch.tensor([-1., 1.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), -2) @@ -6602,7 +6595,6 @@ def test_leaky_relu_inplace_with_neg_slope(self, device): with self.assertRaisesRegex(RuntimeError, "call out-of-place version"): b.backward(torch.ones(2, device=device)) - @skipCUDAIfRocm def test_leaky_relu_inplace_with_zero_slope(self, device): a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), 0.0) @@ -7325,9 +7317,7 @@ def backward(ctx, *grad): instantiate_device_type_tests( TestAutogradDeviceType, globals(), - # Exclude ROCM for now, there are a lot of failures. See - # https://github.com/pytorch/pytorch/issues/30845 - except_for='cuda' if TEST_WITH_ROCM else None + except_for=None ) if __name__ == '__main__': diff --git a/test/test_cuda.py b/test/test_cuda.py index 2a5754523876..498d7e71620e 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2421,7 +2421,6 @@ def _worker(t): self.assertEqual(results[t].sum().item(), size * size) @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - @skipIfRocm def test_cudnn_multiple_threads_same_device(self): # This function is intended to test the lazy creation and reuse of per-thread # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp. @@ -2896,6 +2895,209 @@ def test_max_large_axis(self): def test_to_numpy(self): self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_capture_simple(self): + s1 = torch.cuda.Stream() + + with torch.cuda.stream(s1): + a = torch.zeros((1000,), device="cuda") + a += 1 + g = torch.cuda._Graph() + g.capture_begin() + a += 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s1) + + g.replay() + g.replay() + + self.assertTrue(a.sum().item() == 3000.) + + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_rng_functional(self): + # The caching allocator isn't yet graph-safe. + # In this test, graphed regions try to ensure allocator safety by + # stashing references to all temporaries. This is why we use _fused_dropout + # instead of a public dropout API: _fused_dropout returns the mask temporary + # as well as the output, so we can stash references to both. + # + # TODO: + # Switch to public dropout API when the allocator is made graph-safe. + ops_with_kwargs = ((torch._fused_dropout, {"p": 0.1}), + (torch.nn.functional.rrelu, {"training": True}),) + size = 10000 + + def run(op, kwargs): + a = torch.randn((size,), device="cuda", dtype=torch.float) + + torch.cuda.manual_seed(5) + + # Control + eager_out = a + for _ in range(6): + out = op(eager_out, **kwargs) + # _fused_dropout returns a tuple, rrelu returns a bare tensor. + eager_out = out[0] if isinstance(out, tuple) else out + + graph_in = a.clone() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + # warms up allocator so no mallocs occur in capture + refs = () + graph_out = graph_in + for _ in range(3): + out = op(graph_out, **kwargs) + refs += tuple(out) + graph_out = out[0] if isinstance(out, tuple) else out + del out, refs, graph_out + + torch.cuda.manual_seed(5) + + refs = () + g = torch.cuda._Graph() + g.capture_begin() + graph_out = graph_in + for _ in range(2): + out = op(graph_out, **kwargs) + refs += tuple(out) + graph_out = out[0] if isinstance(out, tuple) else out + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + # Runs a graphed->eager->graphed sequence of RNG ops. + # replay() plays 2 invocations of the op, so the sequence has 6 + # invocations total, matching Control. + # replay() reads from graph_in and writes to graph_out. + g.replay() + out = op(graph_out, **kwargs) + out = op(out[0], **kwargs)[0] if isinstance(out, tuple) else op(out, **kwargs) + graph_in.copy_(out) + g.replay() + + # If replay() updated RNG state correctly, graph_out + # should now hold data equal to eager_out. + try: + self.assertEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op, kwargs in ops_with_kwargs: + run(op, kwargs) + + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_rng_distributions(self): + # The caching allocator isn't yet graph-safe. + # In this test, all ops maintain static references to inputs and outputs + # that persist across replay(), so they should be safe to test with graphs, + # EXCEPT for multinomial which is a complicated compound op. + # + # TODO: + # Uncomment multinomial when the allocator is made graph-safe. + size = 10000 + input = torch.rand((size,), device="cuda", dtype=torch.float) + alloc = torch.empty((size,), device="cuda", dtype=torch.float) + + # Torch ops to test with sample args (tuple) and kwargs (dict) + torch_with_args = (("bernoulli", (input.clone(),), {}), + # ("multinomial", (input.clone(), size, True), {}), + # ("multinomial", (input.clone(), size // 2, False), {}), + ("normal", (input.clone() + 1, input.clone()), {}), + ("poisson", (input.clone(),), {}), + ("rand", (size,), {"device": "cuda", "dtype": torch.float}), + ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), + ("randn", (size,), {"device": "cuda", "dtype": torch.float}),) + + # Tensor methods to test with sample args (tuple) + tensor_with_args = (("bernoulli_", (input.clone(),)), + ("cauchy_", ()), + ("exponential_", ()), + ("geometric_", (0.3,)), + ("log_normal_", ()), + ("normal_", ()), + ("random_", ()), + ("uniform_", ()),) + + def run(module, op, args, kwargs): + torch.cuda.manual_seed(5) + + # Each path runs a dummy op to increment the state a bit before creating controls. + if (module == "torch"): + dummy = getattr(torch, op)(*args, **kwargs) + control1 = getattr(torch, op)(*args, **kwargs) + control2 = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + control1 = alloc.clone() + control2 = alloc.clone() + getattr(dummy, op)(*args) + getattr(control1, op)(*args) + getattr(control2, op)(*args) + + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + torch.cuda.manual_seed(5) + + g = torch.cuda._Graph() + if (module == "torch"): + g.capture_begin() + t1 = getattr(torch, op)(*args, **kwargs) + t2 = getattr(torch, op)(*args, **kwargs) + g.capture_end() + else: + t1 = alloc.clone() + t2 = alloc.clone() + g.capture_begin() + getattr(t1, op)(*args) + getattr(t2, op)(*args) + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + try: + self.assertNotEqual(control1, t1) + self.assertNotEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # Runs a dummy op prelude, as for controls, to make sure replay() + # picks up the dummy op's state increment. + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + getattr(dummy, op)(*args) + + # Runs RNG ops that fill t1 and t2. + g.replay() + + try: + self.assertEqual(control1, t1) + self.assertEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op_with_args in torch_with_args: + run("torch", *op_with_args) + + for meth_with_args in tensor_with_args: + # Adds an empty dict for kwargs, which none of the Tensor methods use + run("Tensor", *(meth_with_args + ({},))) + class TestCudaComm(TestCase): def _test_broadcast(self, input): @@ -3280,5 +3482,6 @@ class TestNamedTupleInput_1(NamedTuple): self.assertEqual(expected_a, x.a) self.assertEqual(expected_b, x.b) + if __name__ == '__main__': run_tests() diff --git a/test/test_dataloader.py b/test/test_dataloader.py index a1afc216d42a..047297c438b7 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1454,6 +1454,15 @@ def test_random_sampler_len_with_replacement(self): self.assertEqual(int(math.ceil(float(num_samples) / batch_size)), count_num_samples_in_data_loader) + def test_distributed_sampler_invalid_rank(self): + from torch.utils.data.distributed import DistributedSampler + dataset = torch.IntTensor(range(10)) + with self.assertRaisesRegex(ValueError, "Invalid rank"): + sampler = DistributedSampler(dataset, 3, 3) + + with self.assertRaisesRegex(ValueError, "Invalid rank"): + sampler = DistributedSampler(dataset, 3, -1) + def test_duplicating_data_with_drop_last(self): from torch.utils.data.distributed import DistributedSampler diff --git a/test/test_foreach.py b/test/test_foreach.py index eff6d969c5e5..c55c4e71dab0 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -53,6 +53,9 @@ class TestForeach(TestCase): (torch._foreach_log1p, torch._foreach_log1p_, torch.log1p, True, False), (torch._foreach_round, torch._foreach_round_, torch.round, False, False), (torch._foreach_frac, torch._foreach_frac_, torch.frac, False, False), + (torch._foreach_reciprocal, torch._foreach_reciprocal_, torch.reciprocal, True, True), + (torch._foreach_sigmoid, torch._foreach_sigmoid_, torch.sigmoid, True, False), + (torch._foreach_trunc, torch._foreach_trunc_, torch.trunc, False, False), # See test_abs # (torch._foreach_abs, torch._foreach_abs_, torch.abs, True, True), @@ -173,7 +176,7 @@ def test_unary_ops(self, device, dtype): control_dtype = torch.float32 if (self.device_type == 'cuda' and (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype - if self.device_type == 'cpu' and dtype == torch.half and torch_op not in [torch.neg, torch.frac]: + if self.device_type == 'cpu' and dtype == torch.half and torch_op not in [torch.neg, torch.frac, torch.reciprocal]: with self.assertRaisesRegex(RuntimeError, r"not implemented for \'Half\'"): expected = [torch_op(tensors1[i]) for i in range(N)] @@ -191,13 +194,14 @@ def test_unary_ops(self, device, dtype): break if dtype in [torch.complex64, torch.complex128] and not support_complex: - # not using assertRaisesRegex due to different error messages - with self.assertRaises(RuntimeError): - expected = [torch_op(tensors1[i]) for i in range(N)] + if not (self.device_type == 'cpu' and torch_op in [torch.sigmoid]): + # not using assertRaisesRegex due to different error messages + with self.assertRaises(RuntimeError): + expected = [torch_op(tensors1[i]) for i in range(N)] - with self.assertRaises(RuntimeError): - res = fe_op(tensors1) - break + with self.assertRaises(RuntimeError): + res = fe_op(tensors1) + break expected = [torch_op(tensors1[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)] res = fe_op(tensors1) diff --git a/test/test_fx.py b/test/test_fx.py index af11f9615cb6..221fae3c518a 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1154,6 +1154,12 @@ def forward(self): m = FooBar1234() self.checkGraphModule(m, ()) + def test_namedtuple_return_trace(self): + class NamedTupReturn(torch.nn.Module): + def forward(self, x): + return Pair(x, x) + + traced = symbolic_trace(NamedTupReturn()) if __name__ == '__main__': run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 57201ded332e..6e9c877b8de6 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1,13 +1,14 @@ import torch import unittest import sys -from typing import Dict +from typing import Callable, Dict, Union, List from torch.fx.symbolic_trace import symbolic_trace from torch.fx.graph_module import GraphModule from torch.fx.node import Node from torch.fx.experimental import graph_manipulation from torch.fx.experimental.accelerator_partitioner import Partitioner from torch.fx.experimental.rewriter import RewritingTracer +from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes from torch.testing._internal.common_utils import run_tests from torch.testing._internal.jit_utils import JitTestCase from torch.fx.experimental.subgraph_creation_example import split_module @@ -20,7 +21,6 @@ PartitionMode ) from torch.fx.experimental.fuser import fuse -from typing import Union, Callable try: from torchvision.models import resnet18 @@ -161,6 +161,37 @@ def forward(self, a, b): catch_runtime_error = True assert catch_runtime_error + def test_large_node_error(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a): + linear = self.linear(a) + add = linear + a + return add + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a]) + partitioner = Partitioner() + devices = [ + Device("dev_0", 40, 0), + Device("dev_1", 40, 0), + Device("dev_2", 40, 0), + Device("dev_3", 40, 0), + Device("dev_4", 40, 0) + ] + partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + catch_runtime_error = False + try: + ret = partitioner.partition_graph(traced, m, partitioner_config) + except RuntimeError: + catch_runtime_error = True + assert catch_runtime_error + def test_partition_node_manipulation(self): class TestModule(torch.nn.Module): def forward(self, a, b): @@ -187,7 +218,6 @@ def forward(self, a, b): partition.remove_node(selected_node) assert(partition.used_mem_bytes == 80) - def test_size_based_partition(self): class TestModule(torch.nn.Module): def __init__(self): @@ -779,6 +809,40 @@ def forward(self, x): t = torch.randn(2, 2) self.assertEqual(module.Foo()(t), mod(t)) + def test_fetch(self): + attrs_for_lowering: Dict[str, List[str]] = { + "torch.nn.modules.conv.Conv2d": [ + "weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode" + ], + "torch.nn.modules.batchnorm.BatchNorm2d": [ + "weight", "bias", "running_mean", "running_var", "eps" + ], + } + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 2) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, a): + a = self.conv(a) + a += a + return self.bn(a) + + mod = TestModule() + traced = symbolic_trace(mod) + lift_lowering_attrs_to_nodes(traced) + + for node in traced.graph.nodes: + if node.op == "call_module": + assert hasattr(node, "attrs_for_lowering") + para_list = attrs_for_lowering[node.attrs_for_lowering["name"]] + + # node.attrs_for_lowering has an addition field of class name + assert len(para_list) + 1 == len(node.attrs_for_lowering) + for p_name in para_list: + assert p_name in node.attrs_for_lowering if __name__ == "__main__": diff --git a/test/test_indexing.py b/test/test_indexing.py index f3430a158d89..b92fd94e8cbd 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -764,7 +764,7 @@ def test_int_indices(self, device): @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool) @dtypesIfCPU(torch.float, torch.long, torch.bool, torch.bfloat16) - @dtypesIfCUDA(torch.half, torch.long, torch.bool) + @dtypesIfCUDA(torch.half, torch.long, torch.bool, torch.bfloat16) def test_index_put_src_datatype(self, device, dtype): src = torch.ones(3, 2, 4, device=device, dtype=dtype) vals = torch.ones(3, 2, 4, device=device, dtype=dtype) diff --git a/test/test_jit.py b/test/test_jit.py index 836066a7f84b..df3a857b485f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1248,56 +1248,6 @@ def forward(self, x): FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph) - def test_reconstruct_scopes(self): - class SubModule(torch.nn.Module): - def __init__(self): - super(SubModule, self).__init__() - - def bar(self, x): - return x + x - - def forward(self, x): - return x * self.bar(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - self.sub = SubModule() - - def forward(self, x): - return self.sub(x) + x - - traced = torch.jit.trace(MyModule(), torch.zeros(1)) - g = traced.graph - torch._C._jit_pass_inline(g) - torch._C._jit_pass_reconstruct_scopes(traced._c, g) - FileCheck().check("scope: top(MyModule).sub(SubModule).forward").run(g) - - def test_reconstruct_scopes_duplicated_class_types(self): - class SubModule(torch.nn.Module): - def __init__(self): - super(SubModule, self).__init__() - - def forward(self, x): - return x + 2 - - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - self.sub1 = SubModule() - self.sub2 = SubModule() - - def forward(self, x): - return self.sub1(x) + self.sub2(x) - - traced = torch.jit.trace(MyModule(), torch.zeros(1)) - g = traced.graph - torch._C._jit_pass_inline(g) - torch._C._jit_pass_reconstruct_scopes(traced._c, g) - FileCheck().check_dag("scope: top(MyModule).sub1(SubModule).forward") \ - .check_dag("scope: top(MyModule).sub2(SubModule).forward") \ - .run(g) - def test_expand_quantlint(self): pass @@ -12128,10 +12078,10 @@ def tuple_slice(a): scripted_fn = torch.jit.script(tuple_slice) self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3)) tuple_graph = scripted_fn.graph - slices = tuple_graph.findAllNodes("prim::TupleSlice") + slices = tuple_graph.findAllNodes("prim::TupleConstruct") num_outputs = set(len(x.output().type().elements()) for x in slices) - # one tuple slice should have an output with 2 elements, other 4 - self.assertTrue(num_outputs == {2, 4}) + # there should be only one tupleSlice with length of 2 + self.assertTrue(num_outputs == {2}) self.run_pass('lower_all_tuples', tuple_graph) self.assertTrue('Tuple' not in str(tuple_graph)) @@ -12142,6 +12092,26 @@ def test_indexing_end_out_of_bounds(): self.assertEqual(test_indexing_end_out_of_bounds(), ()) + def test_stepped_tuple_slicing(self): + + def check_slicing_tuple(slicing, tuple_type, tuple): + template = dedent(""" + def func(x): + # type: ({}) -> Any + return x{} + """) + self._check_code(template.format(tuple_type, slicing), "func", [tuple]) + + check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2)) + check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[5:7:-2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)) + check_slicing_tuple("[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + def test_lower_nested_tuples(self): @torch.jit.script def test(): @@ -15733,7 +15703,7 @@ def fn(*inputs, **kwargs): # alias annotation testing if not is_magic_method and test_name not in EXCLUDE_SCRIPT and not exclude_tensor_method(name, test_name): - check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable) + check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable, aten_name=name) check(name) inplace_name = name + '_' diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index bd5f7ae3af6e..8b04418fa640 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -71,6 +71,19 @@ def setUp(self): torch._C._jit_set_texpr_fuser_enabled(True) self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + self.int_dtypes = [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.bool, + ] + self.fp_dtypes = [ + torch.float16, + torch.float32, + torch.float64, + ] + self.dtypes = self.int_dtypes + self.fp_dtypes def tearDown(self): torch._C._jit_set_profiling_executor(self.old_profiling_executor) @@ -461,21 +474,13 @@ def test_bitwise_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.bool, - ] binary_ops = [ operator.__and__, operator.__or__, operator.__xor__ ] devices = self.devices - for dtype, op, device in product(dtypes, binary_ops, devices): + for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -500,20 +505,12 @@ def test_minmax_int_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.bool, - ] binary_ops = [ torch.min, torch.max ] devices = self.devices - for dtype, op, device in product(dtypes, binary_ops, devices): + for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1211,22 +1208,10 @@ def data_for(self, dtype, device="cuda", size=None): else: return v.to(dtype) - @unittest.skipIf(not LLVM_ENABLED, "TODO: bugs in ir eval") def test_unary_ops(self): def apply(fn): return lambda x: fn(x) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] unary_ops = [ torch.lgamma, torch.sigmoid, @@ -1260,11 +1245,10 @@ def apply(fn): torch.trunc, torch.frac, lambda x: torch.threshold(x, 0, -10), - # FIXME: fails on cpu with dtype=uint8 - # lambda x: torch.clamp(x, -10, 10), + lambda x: torch.clamp(x, -10, 10), ] sizes = [(1,), (2,), (4, 4)] - for dtype, op, device, size in product(dtypes, unary_ops, self.devices, sizes): + for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): try: x = self.data_for(dtype, device, size=size) fn = apply(op) @@ -1288,18 +1272,7 @@ def test_binary_ops(self): def apply(fn): return lambda x, y: fn(x, y) - dtypes = [ - # FIXME: Fails in IR Eval: torch.int8 and_ cpu - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] + # FIXME: Fails in IR Eval: torch.int8 and_ cpu binary_ops = [ operator.__and__, operator.__or__, @@ -1331,7 +1304,7 @@ def apply(fn): torch.remainder, ] devices = self.devices - for dtype, op, device in product(dtypes, binary_ops, devices): + for dtype, op, device in product(self.dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1357,18 +1330,7 @@ def test_binary_tensor_scalar_ops(self): def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - # FIXME: Fails in IR Eval: torch.int64 and_ cpu - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool - ] + # FIXME: Fails in IR Eval: torch.int64 and_ cpu binary_ops = [ operator.__and__, operator.__or__, @@ -1378,11 +1340,9 @@ def apply_with_scalar(fn, scalar): torch.mul, torch.eq, torch.ne, - - # FIXME: fails with dtype=uint8, scalar=-1 - # torch.ge, - # torch.lt, - # torch.gt, + torch.ge, + torch.lt, + torch.gt, # FIXME: segfaults on CPU backend # operator.__rshift__, @@ -1392,7 +1352,7 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] - for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) @@ -1415,17 +1375,6 @@ def test_binary_div_ops(self): def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool - ] binary_ops = [ torch.div, torch.remainder, @@ -1435,7 +1384,7 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, -2.0, -1] # skip 0 - for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) @@ -1459,7 +1408,6 @@ def apply_with_scalar(fn, scalar): dtypes = [ torch.int8, - torch.uint8, torch.int16, torch.int32, torch.int64, @@ -1500,23 +1448,12 @@ def test_ternary_ops(self): def apply(fn): return lambda x, y, z: fn(x, y, z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] ternary_ops = [ torch.lerp, torch.addcmul, ] devices = self.devices - for dtype, op, device in product(dtypes, ternary_ops, devices): + for dtype, op, device in product(self.dtypes, ternary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1542,22 +1479,11 @@ def test_list_ops(self): def apply(fn): return lambda x, y, z: fn([x * x, y * y, z * z]) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] devices = self.devices list_ops = [ torch.cat, ] - for dtype, op, device in product(dtypes, list_ops, devices): + for dtype, op, device in product(self.dtypes, list_ops, devices): try: x = self.data_for(dtype, device, size=[5, 4, 1, 7]) y = self.data_for(dtype, device, size=[5, 4, 1, 7]) @@ -1582,24 +1508,13 @@ def test_where_ops(self): def apply(fn): return lambda cond, x, y: fn(cond, x, y) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] ops = [ torch.where, lambda cond, x, y: torch.where(cond, x, 3.1415), lambda cond, x, y: torch.where(cond, 42, y), ] devices = self.devices - for dtype, op, device in product(dtypes, ops, devices): + for dtype, op, device in product(self.dtypes, ops, devices): try: cond = self.data_for(torch.bool, device) x = self.data_for(dtype, device) @@ -1626,6 +1541,7 @@ def fn(x): return x * x + x unsupported_dtypes = [ + torch.uint8, torch.bfloat16, torch.complex32, torch.complex64, @@ -1694,13 +1610,14 @@ def eager(t1, t2, t3, t4, t: float): t = torch.rand(8, dtype=torch.float, device='cuda') scripted = self.checkScript(eager, (t, t, t, t, 0.1)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_chunk_mul_one(self): - def eager(x): - z, y, w = torch.chunk(x, 3, -1) - return z * 3, y, w - x = torch.rand(64, 1, 3072, dtype=torch.float, device='cuda') - script = self.checkScript(eager, (x,)) + for device in self.devices: + def eager(x): + z, y, w = torch.chunk(x, 3, -1) + return z * 3, y, w + x = torch.rand(64, 1, 3072, dtype=torch.float, device=device) + z, y, w = eager(x) + script = self.checkScript(eager, (x,)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_eq_unsqueeze_type_as(self): diff --git a/test/test_kernel_launch_checks.py b/test/test_kernel_launch_checks.py index 079a7182a1fc..698a5cda2a42 100644 --- a/test/test_kernel_launch_checks.py +++ b/test/test_kernel_launch_checks.py @@ -26,9 +26,16 @@ def test_check_code(self): """)) # Does it work for macros? - self.assertEqual(0, check_code_for_cuda_kernel_launches(""" -#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \\ + self.assertEqual(0, check_code_for_cuda_kernel_launches(r""" +#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \ C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ + indexAddSmallIndex \ + <<>>( \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); """)) def test_check_cuda_launches(self): diff --git a/test/test_linalg.py b/test/test_linalg.py index 3fa677d2b1de..5e7e0c273dcf 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -867,23 +867,61 @@ def test_kron_errors_and_warnings(self, device, dtype): # as expected, according to the function's documentation @skipCUDAIfNoMagma def test_norm_dtype(self, device): - def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype): + def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype): + # Determine the best dtype to use for comparisons between tensors + # of two different types + def get_compare_dtype(type0, type1): + types_32bit_based = [torch.float, torch.cfloat] + is_complex = type0.is_complex or type1.is_complex + + if type0 in types_32bit_based or type1 in types_32bit_based: + return torch.cfloat if is_complex else torch.float + else: + return torch.cdouble if is_complex else torch.double + + compare_dtype = get_compare_dtype(from_dtype, to_dtype) + + def get_value_type(dtype): + if dtype == torch.cfloat: + return torch.float + elif dtype == torch.cdouble: + return torch.double + elif dtype == torch.complex32: + return torch.float16 + else: + return dtype + msg = ( f'input_size={input_size}, ord={ord}, keepdim={keepdim}, ' f'from_dtype={from_dtype}, to_dtype={to_dtype}') input = torch.randn(*input_size, dtype=from_dtype, device=device) - result = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=from_dtype) - self.assertEqual(result.dtype, from_dtype, msg=msg) - result_converted = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype) - self.assertEqual(result_converted.dtype, to_dtype, msg=msg) - self.assertEqual(result.to(compare_dtype), result_converted.to(compare_dtype), msg=msg) + result = torch.linalg.norm(input, ord, keepdim=keepdim) + if from_dtype.is_complex: + # By default, norm downgrades a complex input to the corresponding real number type + self.assertEqual(result.dtype, get_value_type(from_dtype), msg=msg) + else: + self.assertEqual(result.dtype, from_dtype, msg=msg) - result_out_converted = torch.empty_like(result_converted) - torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_converted) - self.assertEqual(result_out_converted.dtype, to_dtype, msg=msg) - self.assertEqual(result_converted, result_out_converted, msg=msg) + result_out = torch.empty((), dtype=to_dtype, device=device) + torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out) + self.assertEqual(result_out.dtype, to_dtype, msg=msg) + self.assertEqual(result.to(compare_dtype), result_out.to(compare_dtype), msg=msg) - ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] + result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype) + self.assertEqual(result_with_dtype.dtype, to_dtype, msg=msg) + + if from_dtype.is_complex: + result_convert_first = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim) + self.assertEqual(result_with_dtype.to(compare_dtype), result_convert_first.to(compare_dtype), msg=msg) + else: + self.assertEqual(result.to(compare_dtype), result_with_dtype.to(compare_dtype), msg=msg) + + result_out_with_dtype = torch.empty_like(result_with_dtype) + torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype) + self.assertEqual(result_out_with_dtype.dtype, to_dtype, msg=msg) + self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg) + + ord_vector = [0, 0.1, -0.1, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None] S = 10 test_cases = [ @@ -893,15 +931,16 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) for keepdim in [True, False]: for input_size, ord_settings in test_cases: for ord in ord_settings: - # float to double - run_test_case(input_size, ord, keepdim, torch.float, torch.double, torch.float) - # double to float - run_test_case(input_size, ord, keepdim, torch.double, torch.double, torch.float) + dtypes = [torch.float, torch.double, torch.cfloat, torch.cdouble] + for from_dtype, to_dtype in itertools.product(dtypes, dtypes): + run_test_case(input_size, ord, keepdim, from_dtype, to_dtype) # Make sure that setting dtype != out.dtype raises an error dtype_pairs = [ (torch.float, torch.double), (torch.double, torch.float), + (torch.cfloat, torch.cdouble), + (torch.cdouble, torch.cfloat), ] for keepdim in [True, False]: for input_size, ord_settings in test_cases: @@ -1008,11 +1047,6 @@ def run_test_case(input, p): for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) for p in norm_types: - # frobenius norm not supported for complex tensors - if dtype.is_complex and p == 'fro': - with self.assertRaisesRegex(RuntimeError, "frobenius norm not supported for complex tensors"): - torch.linalg.cond(input, p) - continue run_test_case(input, p) # test empty batch sizes @@ -1040,7 +1074,7 @@ def run_test_case(input, p): for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) for p in ['fro', 2]: - expected_dtype = a.real.dtype if dtype.is_complex and p == 2 else dtype + expected_dtype = a.real.dtype if dtype.is_complex else dtype expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device) actual = torch.linalg.cond(input, p) self.assertEqual(actual, expected) @@ -1068,7 +1102,7 @@ def test_cond_errors_and_warnings(self, device, dtype): # if non-empty out tensor with wrong shape is passed a warning is given a = torch.ones((2, 2), dtype=dtype, device=device) for p in ['fro', 2]: - real_dtype = a.real.dtype if dtype.is_complex and p == 2 else dtype + real_dtype = a.real.dtype if dtype.is_complex else dtype out = torch.empty(a.shape, dtype=real_dtype, device=device) with warnings.catch_warnings(record=True) as w: # Trigger warning @@ -1231,8 +1265,7 @@ def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): for ord in ord_settings: run_error_test_case(input, ord, dim, keepdim, error_type, error_regex) - # Test complex number inputs for linalg.norm. Some cases are not supported yet, so - # this test also verifies that those cases raise an error. + # Test complex number inputs for linalg.norm @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.cfloat, torch.cdouble) @@ -1241,72 +1274,95 @@ def gen_error_message(input_size, ord, keepdim, dim=None): return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( input_size, ord, keepdim, dim) - if self.device_type == 'cpu': - supported_vector_ords = [0, 1, 3, inf, -1, -2, -3, -inf] - supported_matrix_ords = ['nuc', 1, 2, inf, -1, -2, -inf] - unsupported_vector_ords = [ - (2, r'norm with p=2 not supported for complex tensors'), - (None, r'norm with p=2 not supported for complex tensors'), - ] - unsupported_matrix_ords = [ - ('fro', r'frobenius norm not supported for complex tensors'), - (None, r'norm with p=2 not supported for complex tensors'), - ] - - elif self.device_type == 'cuda': - supported_vector_ords = [inf, -inf] - supported_matrix_ords = [1, inf, -1, -inf] - unsupported_vector_ords = [ - (0, r'norm_cuda" not implemented for \'Complex'), - (1, r'norm_cuda" not implemented for \'Complex'), - (2, r'norm with p=2 not supported for complex tensors'), - (-1, r'norm_cuda" not implemented for \'Complex'), - (-2, r'norm_cuda" not implemented for \'Complex'), - (None, r'norm with p=2 not supported for complex tensors'), - ] - unsupported_matrix_ords = [ - (None, r'norm with p=2 not supported for complex tensors'), - ('fro', r'frobenius norm not supported for complex tensors'), - ] + vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf] + matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, -1, -2, -inf] # Test supported ords for keepdim in [False, True]: # vector norm x = torch.randn(25, device=device, dtype=dtype) xn = x.cpu().numpy() - for ord in supported_vector_ords: + for ord in vector_ords: res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() expected = np.linalg.norm(xn, ord, keepdims=keepdim) msg = gen_error_message(x.size(), ord, keepdim) self.assertEqual(res.shape, expected.shape, msg=msg) self.assertEqual(res, expected, msg=msg) + res_out = torch.Tensor().to(device) + torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out) + self.assertEqual(res_out.shape, expected.shape, msg=msg) + self.assertEqual(res_out.cpu(), expected, msg=msg) + # matrix norm x = torch.randn(25, 25, device=device, dtype=dtype) xn = x.cpu().numpy() - for ord in supported_matrix_ords: - # TODO: Need to fix abort when nuclear norm is given cdouble input: - # "double free or corruption (!prev) Aborted (core dumped)" - if ord == 'nuc' and dtype == torch.cdouble: - continue + for ord in matrix_ords: res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() expected = np.linalg.norm(xn, ord, keepdims=keepdim) msg = gen_error_message(x.size(), ord, keepdim) self.assertEqual(res.shape, expected.shape, msg=msg) self.assertEqual(res, expected, msg=msg) - # Test unsupported ords - # vector norm - x = torch.randn(25, device=device, dtype=dtype) - for ord, error_msg in unsupported_vector_ords: - with self.assertRaisesRegex(RuntimeError, error_msg): - torch.linalg.norm(x, ord) + res_out = torch.Tensor().to(device) + torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out) + self.assertEqual(res_out.shape, expected.shape, msg=msg) + self.assertEqual(res_out.cpu(), expected, msg=msg) + + # Test complex number inputs for linalg.norm + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.cfloat, torch.cdouble) + def test_norm_complex_autograd(self, device, dtype): + def gen_error_message(input_size, ord, keepdim, dim=None): + return "complex norm autograd failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( + input_size, ord, keepdim, dim) + + if dtype == torch.cfloat: + dtype_real = torch.float + elif dtype == torch.cdouble: + dtype_real = torch.double + else: + raise RuntimeError(f'dtype not supported in this test: {dtype}') + + vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf] + matrix_ords = [None, 'fro', 1, inf, -1, -inf] + + # TODO: Fix autograd for matrix orders 'nuc', 2, and -2 by adding complex + # support to svd's backward method. Once this is done, these ords + # should be added to `matrix_ords` above + matrix_ords_unsupported = ['nuc', 2, -2] + + def run_test_case(x, ord, keepdim): + res = torch.linalg.norm(x, ord, keepdim=keepdim) + res.backward() + + x_real = x.clone().detach().abs().requires_grad_(True) + res_real = torch.linalg.norm(x_real, ord, keepdim=keepdim) + res_real.backward() + + msg = gen_error_message(x.size(), ord, keepdim) + + self.assertEqual(res.shape, res_real.shape, msg=msg) + self.assertEqual(res, res_real, msg=msg) + self.assertEqual(x.grad.abs(), x_real.grad, msg=msg) + + # Test supported ords + for keepdim in [False, True]: + for ord in vector_ords: + x = torch.randn(25, dtype=dtype, device=device, requires_grad=True) + run_test_case(x, ord, keepdim) + + for ord in matrix_ords: + x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True) + run_test_case(x, ord, keepdim) - # matrix norm - x = torch.randn(25, 25, device=device, dtype=dtype) - for ord, error_msg in unsupported_matrix_ords: - with self.assertRaisesRegex(RuntimeError, error_msg): - torch.linalg.norm(x, ord) + for ord in matrix_ords_unsupported: + x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True) + with self.assertRaisesRegex( + RuntimeError, + r'svd does not support automatic differentiation for outputs with complex dtype'): + res = torch.linalg.norm(x, ord, keepdim=keepdim) # Test that linal.norm gives the same result as numpy when inputs # contain extreme values (inf, -inf, nan) @@ -1370,12 +1426,6 @@ def run_test_case(input, ord, dim, keepdim, should_error): with self.assertRaises(RuntimeError): torch.linalg.norm(input, ord, dim, keepdim) else: - if dtype in [torch.cfloat, torch.cdouble] and ord in [2, None]: - # TODO: Once these ord values have support for complex numbers, - # remove this error test case - with self.assertRaises(RuntimeError): - torch.linalg.norm(input, ord, dim, keepdim) - return result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) result = torch.linalg.norm(input, ord, dim, keepdim) self.assertEqual(result, result_numpy, msg=msg) @@ -1402,12 +1452,6 @@ def run_test_case(input, ord, dim, keepdim, should_error): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_norm_matrix_degenerate_shapes(self, device, dtype): def run_test_case(input, ord, dim, keepdim, should_error): - if dtype in [torch.cfloat, torch.cdouble] and ord in ['fro', None]: - # TODO: Once these ord values have support for complex numbers, - # remove this error test case - with self.assertRaises(RuntimeError): - torch.linalg.norm(input, ord, dim, keepdim) - return msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' input_numpy = input.cpu().numpy() if should_error: @@ -1495,8 +1539,8 @@ def test_eig_basic(self, device, dtype): self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part self.assertEqual(vv, np_v) - @onlyCPU @skipCPUIfNoLapack + @skipCUDAIfNoMagma @dtypes(torch.double, torch.float) def test_eig_reuse(self, device, dtype): X = torch.randn(4, 4, dtype=dtype, device=device) @@ -1512,15 +1556,15 @@ def test_eig_reuse(self, device, dtype): atol = 1e-8 rtol = 0 self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong') - self.assertFalse(v.is_contiguous(), 'V is contiguous') + self.assertTrue(v.is_contiguous(), 'V is not contiguous') torch.eig(X, True, out=(e, v)) Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t())) self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong') - self.assertFalse(v.is_contiguous(), 'V is contiguous') + self.assertTrue(v.is_contiguous(), 'V is not contiguous') - @onlyCPU @skipCPUIfNoLapack + @skipCUDAIfNoMagma @dtypes(torch.double, torch.float) def test_eig_non_contiguous(self, device, dtype): X = torch.randn(4, 4, dtype=dtype, device=device) @@ -1546,19 +1590,19 @@ def test_eig_invalid_input(self, device, dtype): # test invalid input self.assertRaisesRegex( RuntimeError, - 'A should be 2 dimensional', + 'input should be 2 dimensional', lambda: torch.eig(torch.ones((2)))) self.assertRaisesRegex( RuntimeError, - 'A should be square', + 'input should be square', lambda: torch.eig(torch.ones((2, 3)))) self.assertRaisesRegex( RuntimeError, - 'A should not contain infs or NaNs', + 'input should not contain infs or NaNs', lambda: torch.eig(np.inf * torch.ones((2, 2)))) self.assertRaisesRegex( RuntimeError, - 'A should not contain infs or NaNs', + 'input should not contain infs or NaNs', lambda: torch.eig(np.nan * torch.ones((2, 2)))) @skipCUDAIfNoMagma @@ -1668,39 +1712,26 @@ def gen_error_message(input_size, p, keepdim, dim=None): return "complex norm failed for input size %s, p=%s, keepdim=%s, dim=%s" % ( input_size, p, keepdim, dim) - if device == 'cpu': - for keepdim in [False, True]: - # vector norm - x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device) - xn = x.cpu().numpy() - for p in [0, 1, 3, inf, -1, -2, -3, -inf]: - res = x.norm(p, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, p, keepdims=keepdim) - msg = gen_error_message(x.size(), p, keepdim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - # matrix norm - x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device) - xn = x.cpu().numpy() - for p in ['nuc']: - res = x.norm(p, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, p, keepdims=keepdim) - msg = gen_error_message(x.size(), p, keepdim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - # TODO: remove error test and add functionality test above when 2-norm support is added - with self.assertRaisesRegex(RuntimeError, r'norm with p=2 not supported for complex tensors'): - x = torch.randn(2, device=device, dtype=torch.complex64).norm(p=2) - - # TODO: remove error test and add functionality test above when frobenius support is added - with self.assertRaisesRegex(RuntimeError, r'frobenius norm not supported for complex tensors'): - x = torch.randn(2, 2, device=device, dtype=torch.complex64).norm(p='fro') + for keepdim in [False, True]: + # vector norm + x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device) + xn = x.cpu().numpy() + for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]: + res = x.norm(p, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, p, keepdims=keepdim) + msg = gen_error_message(x.size(), p, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) - elif device == 'cuda': - with self.assertRaisesRegex(RuntimeError, r'"norm_cuda" not implemented for \'ComplexFloat\''): - (1j * torch.randn(25)).norm() + # matrix norm + x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device) + xn = x.cpu().numpy() + for p in ['nuc', 'fro']: + res = x.norm(p, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, p, keepdims=keepdim) + msg = gen_error_message(x.size(), p, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations @dtypes(torch.float) @@ -1895,6 +1926,7 @@ def test_cholesky_solve_batched_non_contiguous(self, device, dtype): self.assertEqual(x, x_exp) @slowTest + @skipCUDAIf(True, "See https://github.com/pytorch/pytorch/issues/48996") @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) @@ -2588,151 +2620,6 @@ def test_old_matrix_rank(self, device, dtype): self.assertEqual(torch.matrix_rank(aaT, True), np.linalg.matrix_rank(aaT.cpu().numpy(), True)) self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True)) - @dtypes(torch.double) - def test_einsum(self, device, dtype): - def check(equation, *operands): - ref = np.einsum(equation, *[operand.cpu().numpy() for operand in operands]) - res = torch.einsum(equation, operands) - self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref))) - - # Check autograd - ops = [op.detach().requires_grad_() for op in operands] - self.assertTrue(torch.autograd.gradcheck(lambda *ops: torch.einsum(equation, ops), ops)) - for op in ops: - self.assertTrue(op._version == 0) - - # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f - x = torch.rand(5, device=device, dtype=dtype) - y = torch.rand(7, device=device, dtype=dtype) - A = torch.randn(3, 5, device=device, dtype=dtype) - B = torch.randn(2, 5, device=device, dtype=dtype) - C = torch.randn(2, 3, 5, device=device, dtype=dtype) - D = torch.randn(2, 5, 7, device=device, dtype=dtype) - E = torch.randn(7, 9, device=device, dtype=dtype) - F = torch.randn(2, 3, 3, 5, device=device, dtype=dtype) - G = torch.randn(5, 4, 6, device=device, dtype=dtype) - H = torch.randn(4, 4, device=device, dtype=dtype) - I = torch.rand(2, 3, 2, device=device, dtype=dtype) - - # Note: gradcheck fails if the same input is given multiple times which is why the - # calls to clone below. (see https://github.com/pytorch/pytorch/issues/9282) - - # Vector operations - check('i->', x) # sum - check('i,i->', x, x.clone()) # dot - check('i,i->i', x, x.clone()) # vector element-wisem mul - check('i,j->ij', x, y) # outer - - # Matrix operations - check("ij->ji", A) # transpose - check("ij->j", A) # row sum - check("ij->i", A) # col sum - check("ij,ij->ij", A, A.clone()) # matrix element-wise mul - check("ij,j->i", A, x) # matrix vector multiplication - check("ij,kj->ik", A, B) # matmul - check("ij,ab->ijab", A, E) # matrix outer product - - # Tensor operations - check("aij,ajk->aik", C, D) # batch matmul - check("ijk,jk->i", C, A) # tensor matrix contraction - check("aij,jk->aik", D, E) # tensor matrix contraction - check("abcd,dfg->abcfg", F, G) # tensor tensor contraction - check("ijk,jk->ik", C, A) # tensor matrix contraction with double indices - check("ijk,jk->ij", C, A) # tensor matrix contraction with double indices - check("ijk,ik->j", C, B) # non contiguous - check("ijk,ik->jk", C, B) # non contiguous with double indices - - # Test diagonals - check("ii", H) # trace - check("ii->i", H) # diagonal - check('iji->j', I) # non-contiguous trace - - # Test ellipsis - check("i...->...", H) - check("ki,...k->i...", A.t(), B) - check("k...,jk->...", A.t(), B) - check('...ik, ...j -> ...ij', C, x) - check('bik,k...j->i...j', C, torch.rand(5, 3, device=device, dtype=dtype)) - check('i...j, ij... -> ...ij', C, torch.rand(2, 5, 2, 3, device=device, dtype=dtype)) - - # torch.bilinear with discontiguous tensors - l = torch.randn(10, 5, device=device, dtype=dtype).transpose(0, 1) - r = torch.randn(20, 5, device=device, dtype=dtype).transpose(0, 1) - w = torch.randn(15, 10, 20, device=device, dtype=dtype) - check("bn,anm,bm->ba", l, w, r) - - # with strided tensors - check("bn,anm,bm->ba", l[:, ::2], w[:, ::2, ::2], r[:, ::2]) - - def test_einsum_corner_cases(self, device): - def check(equation, *operands, expected_output): - tensors = [torch.tensor(operand, dtype=torch.float32, device=device) if not isinstance(operand, tuple) - else torch.rand(operand, dtype=torch.float32, device=device) for operand in operands] - output = torch.einsum(equation, tensors) - self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device)) - - # Test equation variantions - check(' ', 1, expected_output=1) - check(' -> ', 1, expected_output=1) - check(' , ', 2, 2, expected_output=4) - check(' , , ', 2, 2, 2, expected_output=8) - check(' , -> ', 2, 2, expected_output=4) - check(' i ', [1], expected_output=[1]) - check(' i -> ', [1], expected_output=1) - check(' i -> i ', [1], expected_output=[1]) - check(' i , i ', [2], [2], expected_output=4) - check(' i , i -> i ', [2], [2], expected_output=[4]) - - # Test tensors with 0 size dimensions - check('i', [], expected_output=[]) - check(' i j -> j', [[], []], expected_output=[]) - check('ij->i', [[], []], expected_output=[0., 0.]) - check(' i j k , k -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []]) - - # Test broadcasting - check('i,j', [2], [1, 2], expected_output=[[2, 4]]) - check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]]) - - # Test ellipsis broadcasting - check('...', 1, expected_output=1) - check('...->', 1, expected_output=1) - check('...->...', 1, expected_output=1) - check('...', [1], expected_output=[1]) - check('...->', [1], expected_output=1) - check('i...->i', [1], expected_output=[1]) - check('i...->...i', [1], expected_output=[1]) - check('...a->', [[2], [4]], expected_output=6) - check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]]) - - def test_einsum_error_cases(self, device): - def check(equation, operands, regex, exception=RuntimeError): - with self.assertRaisesRegex(exception, r'einsum\(\) ' + regex): - torch.einsum(equation, operands) - - x = torch.rand(2) - y = torch.rand(2, 3) - - check('', [], r'must provide at least one operand') - check('. ..', [x], r'found \'.\' for operand 0 that is not part of any ellipsis') - check('... ...', [x], r'found \'.\' for operand 0 for which an ellipsis was already found') - check('A', [x], r'operand subscript must be in range \[a, z\] but found A for operand 0') - check(',', [x], r'fewer operands were provided than specified in the equation') - check('', [x, x], r'more operands were provided than specified in the equation') - check('', [x], r'the number of subscripts in the equation \(0\) does not match the number ' - r'of dimensions \(1\) for operand 0 and no ellipsis was given') - check('ai', [x], r'the number of subscripts in the equation \(2\) does not match the number ' - r'of dimensions \(1\) for operand 0 and no ellipsis was given') - check('ai...', [x], r'the number of subscripts in the equation \(2\) is more than the number ' - r'of dimensions \(1\) for operand 0') - check('a->... .', [x], r'found \'.\' for output but an ellipsis \(...\) was already found') - check('a->..', [x], r'found \'.\' for output that is not part of any ellipsis \(...\)') - check('a->A', [x], r'subscripts must be in range \[a, z\] but found A for the output') - check('a->aa', [x], r'output subscript a appears more than once in the output') - check('a->i', [x], r'output subscript i does not appear in the equation for any input operand') - check('aa', [y], r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2') - check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: ' - r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]') - def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, device, dtype): triangle_function = torch.triu if upper else torch.tril @@ -3385,6 +3272,80 @@ def run_test(pivot): if self.device_type == 'cuda': run_test(False) + @onlyCPU + @slowTest + @dtypes(torch.double) + def test_einsum(self, device: torch.device, dtype: torch.dtype) -> None: + # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f + x = torch.randn(5, dtype=dtype, device=device) + y = torch.randn(7, dtype=dtype, device=device) + A = torch.randn(3, 5, dtype=dtype, device=device) + B = torch.randn(2, 5, dtype=dtype, device=device) + C = torch.randn(2, 3, 5, dtype=dtype, device=device) + D = torch.randn(2, 5, 7, dtype=dtype, device=device) + E = torch.randn(7, 9, dtype=dtype, device=device) + F = torch.randn(2, 3, 5, 7, dtype=dtype, device=device) + G = torch.randn(7, 11, 13, dtype=dtype, device=device) + H = torch.randn(4, 4, dtype=dtype, device=device) + I = torch.randn(3, 4, 4, dtype=dtype, device=device) + l = torch.randn(5, 10, dtype=dtype, device=device) + r = torch.randn(5, 20, dtype=dtype, device=device) + w = torch.randn(30, 10, 20, dtype=dtype, device=device) + test_list: List[Union[Tuple[str, torch.Tensor], + Tuple[str, torch.Tensor, torch.Tensor], + Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]]] = [ + # -- Vector + ("i->", x), # sum + ("i,i->", x, x), # dot + ("i,i->i", x, x), # vector element-wise mul + ("i,j->ij", x, y), # outer + # -- Matrix + ("ij->ji", A), # transpose + ("ij->j", A), # row sum + ("ij->i", A), # col sum + ("ij,ij->ij", A, A), # matrix element-wise mul + ("ij,j->i", A, x), # matrix vector multiplication + ("ij,kj->ik", A, B), # matmul + ("ij,ab->ijab", A, E), # matrix outer product + # -- Tensor + ("aij,ajk->aik", C, D), # batch matmul + ("ijk,jk->i", C, A), # tensor matrix contraction + ("aij,jk->aik", D, E), # tensor matrix contraction + ("abcd,dfg->abcfg", F, G), # tensor tensor contraction + ("ijk,jk->ik", C, A), # tensor matrix contraction with double indices + ("ijk,jk->ij", C, A), # tensor matrix contraction with double indices + ("ijk,ik->j", C, B), # non contiguous + ("ijk,ik->jk", C, B), # non contiguous with double indices + # -- Diagonal + ("ii", H), # trace + ("ii->i", H), # diagonal + # -- Ellipsis + ("i...->...", H), + ("ki,...k->i...", A.t(), B), + ("k...,jk", A.t(), B), + ("...ii->...i", I), # batch diagonal + # -- Other + ("bn,anm,bm->ba", l, w, r), # as torch.bilinear + ("... ii->...i ", I), # batch diagonal with spaces + ] + for test in test_list: + actual = torch.einsum(test[0], test[1:]) + expected = np.einsum(test[0], *[t.numpy() for t in test[1:]]) + self.assertEqual(expected.shape, actual.shape, msg=test[0]) + self.assertEqual(expected, actual, msg=test[0]) + # test vararg + actual2 = torch.einsum(test[0], *test[1:]) + self.assertEqual(expected.shape, actual2.shape, msg=test[0]) + self.assertEqual(expected, actual2, msg=test[0]) + + def do_einsum(*args): + return torch.einsum(test[0], args) + # FIXME: following test cases fail gradcheck + if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}: + gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:]) + self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps)) + self.assertTrue(A._version == 0) # check that we do not use inplace ops + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.double) diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 49e0a3cb45c0..75b486043c42 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -368,8 +368,11 @@ def test_inherit_tensor(self): t = torch.zeros(5, 5) p = SubProcess(t.share_memory_()) p.start() - p.join(1) - self.assertEqual(t, torch.ones(5, 5) * 3, atol=0, rtol=0) + p.join(2) + if p.exitcode is None: + print("test_inherit_tensor: SubProcess too slow") + else: + self.assertEqual(t, torch.ones(5, 5) * 3, atol=0, rtol=0) @unittest.skipIf(IS_WINDOWS, "Test needs to use fork multiprocessing") def test_autograd_errors(self): diff --git a/test/test_nn.py b/test/test_nn.py index a966d6a1f68f..652b4d85cbed 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3587,49 +3587,57 @@ def test_adaptive_pooling_size_none(self): output = module(input) self.assertEqual(output.size(), (4,) + (2,) * (numel - 1) + (4,)) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_adaptive_pooling_avg_nhwc(self): - input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32, device="cuda") - input = input.contiguous(memory_format=torch.channels_last).requires_grad_() - grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32, device="cuda") - pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + device_list = ['cpu'] + if TEST_CUDA: + device_list.append('cuda') - ref_input = input.detach().clone().contiguous().requires_grad_(True) - ref_grad = grad.detach().clone().contiguous() - ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + for device in device_list: + input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device) + input = input.contiguous(memory_format=torch.channels_last).requires_grad_() + grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device) + pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - out = pool(input) - out.backward(grad) - ref_out = ref_pool(ref_input) - ref_out.backward(ref_grad) + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out) - self.assertEqual(input.grad, ref_input.grad) + out = pool(input) + out.backward(grad) + ref_out = ref_pool(ref_input) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(input.grad, ref_input.grad) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_adaptive_pooling_avg_nhwc_non_contiguous(self): - input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32, device="cuda") - input = input.contiguous(memory_format=torch.channels_last) - input = input[:, ::2, :, :].requires_grad_() - grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32, device="cuda") - grad = grad[:, ::2, :, :] - pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + device_list = ['cpu'] + if TEST_CUDA: + device_list.append('cuda') - ref_input = input.detach().clone().contiguous().requires_grad_(True) - ref_grad = grad.detach().clone().contiguous() - ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + for device in device_list: + input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device) + input = input.contiguous(memory_format=torch.channels_last) + input = input[:, ::2, :, :].requires_grad_() + grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device) + grad = grad[:, ::2, :, :] + pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - out = pool(input) - out.backward(grad) - ref_out = ref_pool(ref_input) - ref_out.backward(ref_grad) + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out) - self.assertEqual(input.grad, ref_input.grad) + out = pool(input) + out.backward(grad) + ref_out = ref_pool(ref_input) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(input.grad, ref_input.grad) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @largeTensorTest('12GB', device='cuda') @@ -6479,16 +6487,6 @@ def test_RNN_change_dropout(self): self.assertNotEqual(output2.data, prev_output) prev_output = output1.data - def _verify_pixel_shuffle(self, input, output, upscale_factor): - for c in range(output.size(1)): - for h in range(output.size(2)): - for w in range(output.size(3)): - height_idx = h // upscale_factor - weight_idx = w // upscale_factor - channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \ - (c * upscale_factor ** 2) - self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx]) - def test_inplace_thnn(self): modules = [nn.ReLU, nn.ELU, nn.SELU, nn.CELU, nn.RReLU] for mod in modules: @@ -6519,18 +6517,74 @@ def test_noncontig_conv_grad_cuda(self, dtype=torch.float): self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0) def test_pixel_shuffle(self): - batch_size = random.randint(1, 3) - upscale_factor = random.randint(2, 5) - channels = random.randint(1, 4) * upscale_factor ** 2 - height = random.randint(5, 10) - width = random.randint(5, 10) - - input = torch.rand(batch_size, channels, height, width, requires_grad=True) - ps = nn.PixelShuffle(upscale_factor) - output = ps(input) - self._verify_pixel_shuffle(input.data, output.data, upscale_factor) - output.backward(output.data) - self.assertEqual(input.data, input.grad.data) + def _test_pixel_shuffle_helper(num_input_dims, valid_channels_dim=True): + # Function to imperatively ensure pixels are shuffled to the correct locations. + # Used to validate the batch operations in pixel_shuffle. + def _verify_pixel_shuffle(input, output, upscale_factor): + for c in range(output.size(-3)): + for h in range(output.size(-2)): + for w in range(output.size(-1)): + height_idx = h // upscale_factor + weight_idx = w // upscale_factor + channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \ + (c * upscale_factor ** 2) + self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx]) + + upscale_factor = random.randint(2, 5) + # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2. + channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1) + height = random.randint(5, 10) + width = random.randint(5, 10) + + if num_input_dims == 1: + input = torch.rand(channels, requires_grad=True) + elif num_input_dims == 2: + input = torch.rand(height, width, requires_grad=True) + else: + batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)] + input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True) + ps = nn.PixelShuffle(upscale_factor) + + if num_input_dims >= 3 and valid_channels_dim: + output = ps(input) + _verify_pixel_shuffle(input, output, upscale_factor) + output.backward(output.data) + self.assertEqual(input.data, input.grad.data) + else: + self.assertRaises(RuntimeError, lambda: ps(input)) + + def test_pixel_shuffle_1D(): + _test_pixel_shuffle_helper(num_input_dims=1) + + def test_pixel_shuffle_2D(): + _test_pixel_shuffle_helper(num_input_dims=2) + + def test_pixel_shuffle_3D_with_valid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=3) + + def test_pixel_shuffle_4D_with_valid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=4) + + def test_pixel_shuffle_5D_with_valid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=5) + + def test_pixel_shuffle_3D_with_invalid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=3, valid_channels_dim=False) + + def test_pixel_shuffle_4D_with_invalid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=4, valid_channels_dim=False) + + def test_pixel_shuffle_5D_with_invalid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=5, valid_channels_dim=False) + + test_pixel_shuffle_1D() + test_pixel_shuffle_2D() + test_pixel_shuffle_3D_with_valid_channels_dim() + test_pixel_shuffle_4D_with_valid_channels_dim() + test_pixel_shuffle_5D_with_valid_channels_dim() + test_pixel_shuffle_3D_with_invalid_channels_dim() + test_pixel_shuffle_4D_with_invalid_channels_dim() + test_pixel_shuffle_5D_with_invalid_channels_dim() def test_elu_inplace_view(self): v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True) @@ -12340,7 +12394,6 @@ def test_batchnorm_eval(self, device): self._test_batchnorm_eval(device) @onlyCUDA - @skipCUDAIfNotRocm def test_batchnorm_eval_bfloat16(self, device): self._test_batchnorm_eval(device, torch.bfloat16) diff --git a/test/test_ops.py b/test/test_ops.py index 1be90f2555f8..090232360309 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -231,69 +231,63 @@ def test_variant_consistency_jit(self, device, dtype, op): for sample in samples: # Acquires variants to test + func = op.get_op() method = op.get_method() inplace = op.get_inplace() - variants = (v for v in (method, inplace) if v is not None) - - # Adds function variant to variant list - # TODO: inplace tests currently fail - # variants = (v for v in (op, method, inplace) if v is not None) - variants = (v for v in (op, method) if v is not None) + variants = { + 'function': func, 'method': method, + # TODO: inplace tests currently fail + # 'inplace': inplace, + } # Test traced and scripted consistency - for variant in variants: + for func_type, variant in variants.items(): + if variant is None: + continue + # Create accessor for script function variant - if variant is op: - name = op.name - func_type = 'function' - elif variant is method: - name = op.name - func_type = 'method' - else: # variant is inplace - assert variant is inplace - name = op.name + "_" - func_type = 'inplace' + name = op.name + '_' if func_type == 'inplace' else op.name # run with disable_autodiff_subgraph_inlining(True) to test # autodiff support. Context manager forces the graph to contain # DifferentiableGraph nodes if they are present with disable_autodiff_subgraph_inlining(): def fn(*inputs, **kwargs): - attr = getattr(inputs[0], name) - output = attr(*inputs[1:], **kwargs) + output = func(*inputs, **kwargs) return op.output_func(output) # bfloat16 grad doesn't work for some operators dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \ - if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16) + if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16) # Check scripted forward, grad, and grad grad script_fn = create_script_fn(self, name, func_type, op.output_func) - check_against_reference(self, + check_against_reference(self, script_fn, - fn, - (*sample.input,) + sample.args, - sample.kwargs, + fn, + (*sample.input,) + sample.args, + sample.kwargs, no_grad=(dtype not in dtypes_to_grad_check)) # Check traced forward, grad, and grad grad traced_fn = create_traced_fn(self, variant) - check_against_reference(self, + check_against_reference(self, traced_fn, - fn, - (*sample.input,) + sample.args, - sample.kwargs, + fn, + (*sample.input,) + sample.args, + sample.kwargs, no_grad=(dtype not in dtypes_to_grad_check)) # Check alias annotation schema for correctness (make # sure inputs that aren't supposed to be modified aren't) - # Note: only runs in float32 and int64 because schema isn't affected by dtype, + # Note: only runs in float32 and int64 because schema isn't affected by dtype, # so running it on all dtypes is would be excessive if dtype in [torch.float32, torch.int32]: - check_alias_annotation(name, (*sample.input,) + sample.args, sample.kwargs) + check_alias_annotation(name, (*sample.input,) + sample.args, sample.kwargs, + func_type=func_type, aten_name=op.aten_name) - # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample + # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample if dtype is torch.float32: # Sandcastle doesn't fuse nodes if IS_SANDCASTLE: diff --git a/test/test_package.py b/test/test_package.py index 894ec8783f1b..ee9140d661b9 100644 --- a/test/test_package.py +++ b/test/test_package.py @@ -119,7 +119,9 @@ def test_resources(self): def test_extern(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: - he.extern_modules(['package_a.subpackage', 'module_a']) + he.extern(['package_a.subpackage', 'module_a']) + he.require_module('package_a.subpackage') + he.require_module('module_a') he.save_module('package_a') hi = PackageImporter(filename) import package_a.subpackage @@ -136,7 +138,7 @@ def test_extern(self): def test_extern_glob(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: - he.extern_modules(['package_a.*', 'module_*']) + he.extern(['package_a.*', 'module_*']) he.save_module('package_a') he.save_source_string('test_module', """\ import package_a.subpackage @@ -158,8 +160,10 @@ def test_extern_glob(self): def test_mock(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: - he.mock_modules(['package_a.subpackage', 'module_a']) + he.mock(['package_a.subpackage', 'module_a']) he.save_module('package_a') + he.require_module('package_a.subpackage') + he.require_module('module_a') hi = PackageImporter(filename) import package_a.subpackage _ = package_a.subpackage @@ -175,7 +179,7 @@ def test_mock(self): def test_mock_glob(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: - he.mock_modules(['package_a.*', 'module*']) + he.mock(['package_a.*', 'module*']) he.save_module('package_a') he.save_source_string('test_module', """\ import package_a.subpackage @@ -199,7 +203,7 @@ def test_custom_requires(self): class Custom(PackageExporter): def require_module(self, name, dependencies): if name == 'module_a': - self.mock_module('module_a') + self.save_mock_module('module_a') elif name == 'package_a': self.save_source_string('package_a', 'import module_a\nresult = 5\n') else: @@ -354,19 +358,22 @@ def load(): self.assertTrue(torch.allclose(*results)) def test_module_glob(self): - from torch.package.exporter import _module_glob_to_re + from torch.package.exporter import _GlobGroup - def check(pattern, should_match, should_not_match): - x = _module_glob_to_re(pattern) + def check(include, exclude, should_match, should_not_match): + x = _GlobGroup(include, exclude) for e in should_match: - self.assertTrue(x.fullmatch(e)) + self.assertTrue(x.matches(e)) for e in should_not_match: - self.assertFalse(x.fullmatch(e)) - - check('torch.*', ['torch.foo', 'torch.bar'], ['tor.foo', 'torch.foo.bar', 'torch']) - check('torch.**', ['torch.foo', 'torch.bar', 'torch.foo.bar'], ['tor.foo', 'torch']) - check('torch.*.foo', ['torch.w.foo'], ['torch.hi.bar.baz']) - check('torch.**.foo', ['torch.w.foo', 'torch.hi.bar.foo'], ['torch.f.foo.z']) + self.assertFalse(x.matches(e)) + + check('torch.*', [], ['torch.foo', 'torch.bar'], ['tor.foo', 'torch.foo.bar', 'torch']) + check('torch.**', [], ['torch.foo', 'torch.bar', 'torch.foo.bar', 'torch'], ['what.torch', 'torchvision']) + check('torch.*.foo', [], ['torch.w.foo'], ['torch.hi.bar.baz']) + check('torch.**.foo', [], ['torch.w.foo', 'torch.hi.bar.foo'], ['torch.f.foo.z']) + check('torch*', [], ['torch', 'torchvision'], ['torch.f']) + check('torch.**', ['torch.**.foo'], ['torch', 'torch.bar', 'torch.barfoo'], ['torch.foo', 'torch.some.foo']) + check('**.torch', [], ['torch', 'bar.torch'], ['visiontorch']) if __name__ == '__main__': main() diff --git a/test/test_profiler.py b/test/test_profiler.py index 797ad0995913..2cd6beaaaf53 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn +import torch.optim +import torch.utils.data from torch.testing._internal.common_utils import ( TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS) from torch.autograd.profiler import profile @@ -14,6 +16,7 @@ HAS_PSUTIL = True except ImportError: HAS_PSUTIL = False +import pickle @unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") @@ -129,5 +132,98 @@ def test_kineto(self): self.assertTrue(found_memcpy) # p.export_chrome_trace("/tmp/test_trace.json") + def test_high_level_trace(self): + """Checks that python side high level events are recorded. + """ + class RepeatedDataset(torch.utils.data.Dataset): + def __init__(self, N, D_in, D_out): + self.N = N + self.x = torch.randn(N, D_in) + self.y = torch.randn(N, D_out) + + def __len__(self): + return self.N + + def __getitem__(self, idx): + return self.x, self.y + + class TwoLayerNet(torch.nn.Module): + def __init__(self, D_in, H, D_out): + super(TwoLayerNet, self).__init__() + self.linear1 = torch.nn.Linear(D_in, H) + self.linear2 = torch.nn.Linear(H, D_out) + + def forward(self, x): + h_relu = self.linear1(x).clamp(min=0) + y_pred = self.linear2(h_relu) + return y_pred + + class CustomSGD(torch.optim.SGD): + def __init__(self, *args, **kwargs): + super(CustomSGD, self).__init__(*args, **kwargs) + + def train(): + for _, data in enumerate(dataloader): + x, y = data[0], data[1] + y_pred = model(x) + loss = criterion(y_pred, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + N, D_in, H, D_out = 8, 10, 5, 2 + model = TwoLayerNet(D_in, H, D_out) + criterion = torch.nn.MSELoss(reduction='sum') + optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + ds = RepeatedDataset(N, D_in, D_out) + dataloader = torch.utils.data.DataLoader(ds, batch_size=1) + + try: + train() + except Exception: + self.assertTrue(False, "Expected no exception without profiling.") + + # Create multiple instances, expect each func is hooked only one time. + # Nested wrappers(repeated patching) will make following test fail. + optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4) + dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1) + + def judge(expected_event_count, prof): + actual_event_count = {} + for e in prof.function_events: + if "#" in e.name: + key = e.name + if key in expected_event_count.keys(): + actual_event_count[key] = actual_event_count.setdefault(key, 0) + 1 + for key, count in expected_event_count.items(): + self.assertTrue((key in actual_event_count.keys()) and (count == actual_event_count[key])) + + with profile() as prof: + train() + expected_event_count = { + # "+1" because the final iteration will enter __next__ but skip the loop body. + "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1), + "Optimizer.step#SGD.step": N, + "Optimizer.zero_grad#SGD.zero_grad": N + } + judge(expected_event_count, prof) + + # Test on pickle/unpickle. Expect to work in multi-processing. + optimizer = pickle.loads(pickle.dumps(optimizer)) + with profile() as prof: + train() + judge(expected_event_count, prof) + + # Test on customized optimizer. + optimizer = CustomSGD(model.parameters(), lr=1e-4) + with profile() as prof: + train() + expected_event_count = { + "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1), + "Optimizer.step#CustomSGD.step": N, + "Optimizer.zero_grad#CustomSGD.zero_grad": N + } + judge(expected_event_count, prof) + if __name__ == '__main__': run_tests() diff --git a/test/test_sparse.py b/test/test_sparse.py index 72a67caa2038..5af630c0acb4 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -10,7 +10,7 @@ import random import unittest from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ - do_test_empty_full, load_tests, TEST_NUMPY, TEST_WITH_ROCM, IS_WINDOWS + do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version from numbers import Number from torch.autograd.gradcheck import gradcheck @@ -1301,7 +1301,6 @@ def test_spadd_hybrid(self): self._test_spadd_shape(10, [50, 30, 20], [2, 0]) @cuda_only - @unittest.skipIf(not TEST_WITH_ROCM, "runs only on ROCm") def test_sparse_add_out_bfloat16(self): # fp32 x, _, _ = self._gen_sparse(3, 5, 10) diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 8a35e71c035c..6192d6c4d6b6 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -843,8 +843,9 @@ def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, else: window = None if expected_error is None: - result = x.stft(n_fft, hop_length, win_length, window, - center=center, return_complex=False) + with self.maybeWarnsRegex(UserWarning, "stft with return_complex=False"): + result = x.stft(n_fft, hop_length, win_length, window, + center=center, return_complex=False) # NB: librosa defaults to np.complex64 output, no matter what # the input dtype ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) @@ -1057,12 +1058,11 @@ def test_complex_stft_onesided(self, device): x.stft(10, window=window, pad_mode='constant', onesided=True) else: y = x.stft(10, window=window, pad_mode='constant', onesided=True, - return_complex=False) - self.assertEqual(y.dtype, torch.double) - self.assertEqual(y.size(), (6, 51, 2)) + return_complex=True) + self.assertEqual(y.dtype, torch.cdouble) + self.assertEqual(y.size(), (6, 51)) - y = torch.rand(100, device=device, dtype=torch.double) - window = torch.randn(10, device=device, dtype=torch.cdouble) + x = torch.rand(100, device=device, dtype=torch.cdouble) with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, pad_mode='constant', onesided=True) @@ -1098,7 +1098,7 @@ def test_fft_input_modification(self, device): def test_istft_round_trip_simple_cases(self, device, dtype): """stft -> istft should recover the original signale""" def _test(input, n_fft, length): - stft = torch.stft(input, n_fft=n_fft, return_complex=False) + stft = torch.stft(input, n_fft=n_fft, return_complex=True) inverse = torch.istft(stft, n_fft=n_fft, length=length) self.assertEqual(input, inverse, exact_dtype=True) @@ -1120,7 +1120,7 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): for sizes in data_sizes: for i in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) - stft = torch.stft(original, return_complex=False, **stft_kwargs) + stft = torch.stft(original, return_complex=True, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) # trim the original for case when constructed signal is shorter than original diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 7a622b0e90c6..5c6c4714a518 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -105,6 +105,21 @@ def trivial_graph(a, b, c): s = torch.tensor([[3, 3], [3, 3]]) return a + b * c + s +def loop_graph(a, b, iters : int): + c = a + b * 2 + for i in range(iters): + c = c + b + c *= 2 + c -= a + return c + +def output_graph(a, b, c, iters : int): + s = torch.tensor([[3, 3], [3, 3]]) + k = a + b * c + s + d : Dict[int, Tensor] = {} + for i in range(iters): + d[i] = k + i + return d class TestStaticRuntime(TestCase): def test_multihead_attention_layer(self): @@ -203,5 +218,63 @@ def test_leaky_relu(self): o_test = tg_a(s)[0] torch.testing.assert_allclose(o_ref, o_test) + def test_fusion_trivial_graph(self): + s = torch.full((2, 2), 2) + tg = torch.jit.script(trivial_graph) + o_ref = tg(s, s, s) + torch._C._fuse_to_static_runtime(tg.graph) + assert "StaticSubgraph" in str(tg.graph) + o_test = tg(s, s, s) + torch.testing.assert_allclose(o_ref, o_test) + + def test_fusion_multihead_attention_layer(self): + HID_DIM = 256 + QUERY_LEN = 8 + BATCH_SIZE = 128 + LAYERS = 3 + HEADS = 8 + DROPOUT = 0.1 + device = torch.device("cpu") + attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) + with torch.no_grad(): + src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) + src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) + + attention.eval() + attention = torch.jit.script(attention) + attention.eval() + o_ref = attention(src, src, src, src_mask) + + torch._C._fuse_to_static_runtime(attention._c) + o_test = attention(src, src, src, src_mask) + + for a, b in zip(o_ref, o_test): + torch.testing.assert_allclose(a, b) + + def test_fusion_loop(self): + a = torch.randn(5, 5) + b = torch.randn(5, 5) + c = 4 + lg = torch.jit.script(loop_graph) + o_ref = lg(a, b, c) + torch._C._fuse_to_static_runtime(lg.graph) + assert "StaticSubgraph" in str(lg.graph) + o_test = lg(a, b, c) + torch.testing.assert_allclose(o_ref, o_test) + + def test_fusion_outputs(self): + a = torch.randn(2, 2) + b = torch.randn(2, 2) + c = 4 + og = torch.jit.script(output_graph) + o_ref = og(a, b, b, c) + torch._C._fuse_to_static_runtime(og.graph) + assert "StaticSubgraph" in str(og.graph) + o_test = og(a, b, b, c) + for i in o_ref.keys(): + torch.testing.assert_allclose(o_ref[i], o_test[i]) + + + if __name__ == "__main__": run_tests() diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index b355005b1c69..9be3e6db5bf0 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -14,7 +14,7 @@ IS_WINDOWS) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA, - onlyCPU, skipCUDAIfNotRocm, largeTensorTest, precisionOverride, dtypes, + onlyCPU, largeTensorTest, precisionOverride, dtypes, onlyCUDA, skipCPUIf, dtypesIfCUDA, dtypesIfCPU) # TODO: refactor tri_tests_args, _compare_trilu_indices, run_additional_tri_tests @@ -2581,7 +2581,6 @@ def test_arange_device_vs_cpu(self, device, dtype): self.assertEqual(cpu_tensor, device_tensor) @onlyCUDA - @skipCUDAIfNotRocm def test_arange_bfloat16(self, device): ref_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.bfloat16, device=device) bfloat16_tensor = torch.arange(0, 4, dtype=torch.bfloat16, device=device) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 5c30c312534f..6cdccf468326 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -420,11 +420,11 @@ def easy(x, y): traced = torch.jit.trace( easy, (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), - torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8)), + torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)), ) a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) - b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8) + b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) x = warmup_and_run_forward(traced, a, b) self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) @@ -1489,6 +1489,39 @@ def simple(a, b): torch._C._jit_set_te_generate_block_code(val) torch._C._jit_texpr_set_fallback_allowed(fall_bk) + def test_strided_output_preserved(self): + def foo(a, b): + return a + b - a + + # smaller, easier to debug example + x = torch.arange(6) + x = torch.as_strided(x, (2, 3), (1, 2)) + total = 0 + for i in range(2): + for j in range(3): + x[i, j] = total + total += 1 + foo_script = torch.jit.script(foo) + foo_script(x, x) + foo_script(x, x) + out_s = foo_script(x, x) + out_eager = foo(x, x) + self.assertEqual(out_s, out_eager) + self.assertEqual(out_s.stride(), out_eager.stride()) + self.assertLastGraphAllFused() + + # more dims + N, C, H, W, = 2, 3, 4, 5 + x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last) + foo_script = torch.jit.script(foo) + foo_script(x, x) + foo_script(x, x) + out_s = foo_script(x, x) + out_eager = foo(x, x) + self.assertEqual(out_s, out_eager) + self.assertEqual(out_s.stride(), out_eager.stride()) + self.assertLastGraphAllFused() + def test_alias_analysis_module(self): class AliasModule(nn.Module): def __init__(self): @@ -1595,6 +1628,26 @@ def getModule(script): torch.testing.assert_allclose(ref, test) + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + def test_multiple_outputs(self): + # A bug reported internally similar to the one reported in #48533 + def foo(a, b, c): + t_next = c + 1 + t5 = t_next * b + t6 = torch.unsqueeze(t_next, 1) + t7 = a * t6 + return (t7, t5, t_next) + + a = torch.rand(20, 20, dtype=torch.float32, device='cuda') + b = torch.rand(20 * 29, dtype=torch.float32, device='cuda').as_strided([20], [29]) + c = torch.ones(20, dtype=torch.int64, device='cuda') + traced = torch.jit.trace(foo, (a, b, c)) + ref = foo(a, b, c) + exp = traced(a, b, c) + exp = traced(a, b, c) + for i in range(3): + assert(torch.allclose(ref[i], exp[i])) + if __name__ == '__main__': unittest.main() diff --git a/test/test_testing.py b/test/test_testing.py index 4a1b33831a44..b87345186cb3 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -432,6 +432,12 @@ def test_isclose_atol_rtol_greater_than_zero(self, device, dtype): with self.assertRaises(RuntimeError): torch.isclose(t, t, atol=-1, rtol=-1) + def test_assert_messages(self, device): + self.assertIsNone(self._get_assert_msg(msg=None)) + self.assertEqual("\nno_debug_msg", self._get_assert_msg("no_debug_msg")) + self.assertEqual("no_user_msg", self._get_assert_msg(msg=None, debug_msg="no_user_msg")) + self.assertEqual("debug_msg\nuser_msg", self._get_assert_msg(msg="user_msg", debug_msg="debug_msg")) + instantiate_device_type_tests(TestTesting, globals()) if __name__ == '__main__': diff --git a/test/test_torch.py b/test/test_torch.py index fde60ca4174f..b4d9ad6f23c0 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -19,15 +19,15 @@ from torch import multiprocessing as mp from torch.testing._internal.common_utils import ( TestCase, TEST_WITH_ROCM, run_tests, - IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, + IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, - skipIfRocm, skipIfNoSciPy, + skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, - skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, + skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA, @@ -254,8 +254,8 @@ def get_tensor(size, dtype, device, contiguous): height = 5 width = 5 for device in torch.testing.get_all_device_types(): - for dt1 in torch.testing.get_all_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False): - for dt2 in torch.testing.get_all_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False): + for dt1 in torch.testing.get_all_dtypes(): + for dt2 in torch.testing.get_all_dtypes(): for contiguous in [True, False]: x1 = get_tensor((height, width), dt1, device, contiguous) x2 = get_tensor((height, width), dt2, device, contiguous) @@ -341,9 +341,6 @@ def test_device(self): self.assertEqual(90, cuda90.index) self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu:1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1)) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1)) self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 ')) self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2')) @@ -356,7 +353,6 @@ def test_device(self): self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3')) - self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1)) self.assertRaises(RuntimeError, lambda: torch.device(-1)) self.assertRaises(RuntimeError, lambda: torch.device('other')) @@ -1856,15 +1852,14 @@ def test_storage_casts(self): self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage') self.assertIs(complexdouble_storage.dtype, torch.complex128) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") def test_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.FloatStorage.from_file(f.name, True, size) + def assert_with_filename(filename): + size = 10000 + s1 = torch.FloatStorage.from_file(filename, True, size) t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) # check mapping - s2 = torch.FloatStorage.from_file(f.name, True, size) + s2 = torch.FloatStorage.from_file(filename, True, size) t2 = torch.FloatTensor(s2) self.assertEqual(t1, t2, atol=0, rtol=0) @@ -1878,15 +1873,24 @@ def test_from_file(self): t2.fill_(rnum) self.assertEqual(t1, t2, atol=0, rtol=0) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") + # release the tensors + del s1, t1, s2, t2 + + with TemporaryFileName() as fname: + assert_with_filename(fname) + + if IS_FILESYSTEM_UTF8_ENCODING: + with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname: + assert_with_filename(fname) + def test_torch_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.from_file(f.name, True, size, dtype=torch.float) + def assert_with_filename(filename): + size = 10000 + s1 = torch.from_file(filename, True, size, dtype=torch.float) t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) # check mapping - s2 = torch.from_file(f.name, True, size, dtype=torch.float) + s2 = torch.from_file(filename, True, size, dtype=torch.float) t2 = torch.FloatTensor(s2) self.assertEqual(t1, t2, atol=0, rtol=0) @@ -1900,6 +1904,16 @@ def test_torch_from_file(self): t2.fill_(rnum) self.assertEqual(t1, t2, atol=0, rtol=0) + # release the tensors + del s1, t1, s2, t2 + + with TemporaryFileName() as fname: + assert_with_filename(fname) + + if IS_FILESYSTEM_UTF8_ENCODING: + with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname: + assert_with_filename(fname) + def test_print(self): default_type = torch.Tensor().type() for t in torch._tensor_classes: @@ -4994,6 +5008,7 @@ def test_ternary_op_mem_overlap(self, device, dtype): expected_failure=not has_input_output_mem_overlap_check) @dtypes(torch.double) + @onlyOnCPUAndCUDA def test_copy_mem_overlap(self, device, dtype): self.check_internal_mem_overlap( torch.Tensor.copy_, num_inputs=2, dtype=dtype, device=device) @@ -5002,14 +5017,49 @@ def test_copy_mem_overlap(self, device, dtype): self.unary_check_input_output_mem_overlap( doubles, sz, lambda input, out: out.copy_(input)) + @onlyOnCPUAndCUDA def test_index_add_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) y = torch.rand((6,), device=device) - ind = torch.tensor([0, 2, 3], device=device) + ind = torch.tensor([2, 1, 0], device=device) value = torch.rand((3,), device=device) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x.index_add_(0, ind, value) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + y.index_add_(0, ind, y[:3]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_add_(0, ind, ind.clone()) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_add_(0, ind.clone(), ind) + @onlyOnCPUAndCUDA + def test_index_copy_mem_overlap(self, device): + x = torch.rand((1,), device=device).expand((6,)) + y = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device) + value = torch.rand((3,), device=device) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + x.index_copy_(0, ind, value) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + y.index_copy_(0, ind, y[:3]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_copy_(0, ind, ind.clone()) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_copy_(0, ind.clone(), ind) + + @onlyOnCPUAndCUDA + def test_index_fill_mem_overlap(self, device): + x = torch.rand((1,), device=device).expand((6,)) + y = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device) + value = torch.rand((3,), device=device) + + with self.assertWarnsRegex(UserWarning, "index_fill_ on expanded tensors"): + x.index_fill_(0, ind, 1.0) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_fill_(0, ind, 0) + + @onlyOnCPUAndCUDA def test_shift_mem_overlap(self, device): x = torch.rand(3, device=device) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -5017,6 +5067,7 @@ def test_shift_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x[:-1] >>= x[1:] + @onlyOnCPUAndCUDA def test_bernoulli_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) @@ -5030,16 +5081,26 @@ def test_bernoulli_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.bernoulli(torch.rand_like(x), out=x) + @onlyOnCPUAndCUDA def test_index_put_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) y = torch.rand((6,), device=device) - ind = torch.tensor([0, 2, 3], device=device) + ind = torch.tensor([2, 1, 0], device=device) value = torch.rand((3,), device=device) with self.assertWarnsRegex(UserWarning, 'expanded tensors'): x.index_put_((ind,), value) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): y.index_put_((ind,), y[0]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_put_((ind,), ind) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + y.index_put_((ind,), y[:3]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_put_((ind,), ind.clone()) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_put_((ind.clone(),), ind) + @onlyOnCPUAndCUDA def test_masked_fill_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) mask = torch.tensor([True, False, True, True, False, False], device=device) @@ -5050,13 +5111,22 @@ def test_masked_fill_mem_overlap(self, device): with self.assertWarnsRegex(UserWarning, 'expanded tensors'): x.masked_fill_(mask, fill_val) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + mask[1:].masked_fill_(mask[:-1], False) + + @onlyOnCPUAndCUDA def test_masked_select_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((3,)) y = torch.rand((6,), device=device) mask = torch.tensor([True, False, True, True, False, False], device=device) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.masked_select(y, mask, out=x) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.masked_select(y, mask, out=y) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.masked_select(mask.clone(), mask, out=mask) + @onlyOnCPUAndCUDA def test_masked_scatter_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) src = torch.rand((3,), device=device) @@ -5065,6 +5135,7 @@ def test_masked_scatter_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x.masked_scatter_(mask, src) + @onlyOnCPUAndCUDA def test_index_select_mem_overlap(self, device): x = torch.rand((1, 6), device=device).expand((2, 6)) y = torch.rand((3, 6), device=device) @@ -5072,20 +5143,43 @@ def test_index_select_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.index_select(y, 1, ind, out=x) + @onlyOnCPUAndCUDA def test_scatter_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) src = torch.rand((3,), device=device) - ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x.scatter_(0, ind, src) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + src.scatter_(0, ind, src) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.scatter_(0, ind, ind.clone()) + @onlyOnCPUAndCUDA def test_gather_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((3,)) src = torch.rand((6,), device=device) - ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.gather(src, 0, ind, out=x) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.gather(src, 0, ind, out=src) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.gather(ind.clone(), 0, ind[1:], out=ind[:1]) + + @onlyOnCPUAndCUDA + def test_take_mem_overlap(self, device): + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.take(src, ind, out=x) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.take(src, ind, out=src) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.take(ind.clone(), ind[1:], out=ind[:-1]) + @onlyCUDA def test_multinomial_device_constrain(self, device): @@ -6094,7 +6188,7 @@ def _where_valid_scalar_tensor_combination(self, scalar_type, dtype): return False @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) def test_where_scalar_invalid_combination_raises(self, device, dtype): @@ -6106,7 +6200,7 @@ def checkRaises(scalar_type, dtype, condition, x, scalar_1): self._test_where_scalar_template(device, dtype, checkRaises) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) def test_where_scalar_valid_combination(self, device, dtype): @@ -6162,7 +6256,6 @@ class TestDevicePrecision(TestCase): exact_dtype = True @onlyCUDA - @skipCUDAIfNotRocm def test_index_add_bfloat16(self, device): inp_tensor = torch.randn(5, 3, device='cpu').bfloat16() t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.bfloat16, device='cpu') @@ -6316,10 +6409,6 @@ def test_copy_broadcast(self, device) -> None: torch.uint8 ] -# _types2 adds bfloat16 type to _types only on ROCm. Should eventually be unified -# with _types when bfloat16 bringup is complete on all platforms. -_types2 = _types + [torch.bfloat16] if TEST_WITH_ROCM else _types - _float_types = [torch.half, torch.float, torch.double] _complex_types = [torch.cfloat, torch.cdouble] @@ -6328,10 +6417,6 @@ def test_copy_broadcast(self, device) -> None: _float_types_no_half = [torch.float, torch.double] -# _float_types2 adds bfloat16 type to _float_types only on ROCm. Should eventually be unified -# with _float_types when bfloat16 bringup is complete on all platforms -_float_types2 = _float_types + [torch.bfloat16] if TEST_WITH_ROCM else _float_types - _signed_types = [ torch.half, torch.bfloat16, torch.float, torch.double, torch.int8, torch.short, torch.int, torch.long @@ -6605,10 +6690,14 @@ def inner(self, device, dtype): ('dot', '', _medium_1d, lambda t, d: [_medium_1d(t, d)], 1e-2, 1e-5, 1e-5, _float_types + _complex_types, _cpu_types, False), ('element_size', '', _medium_1d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False), - ('eq', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('eq', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('ne', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('ne', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types2), + ('eq', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('eq', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('ne', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('ne', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), ('equal', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('equal', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), @@ -6622,10 +6711,14 @@ def inner(self, device, dtype): ('lcm', '', _small_3d, lambda t, d: [_small_3d(t, d)], 0, 0, 0, [torch.int16, torch.int32, torch.int64], [torch.int16, torch.int32, torch.int64], True, [onlyOnCPUAndCUDA]), - ('ge', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('le', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('gt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('lt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), + ('ge', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('le', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('gt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('lt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), ('is_contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), # TODO: can't check negative case - cross-device copy is contiguous ('is_same_size', 'negative', _medium_2d, lambda t, d: [_small_3d(t, d)], @@ -6689,10 +6782,12 @@ def inner(self, device, dtype): ('narrow', '', _small_3d, lambda t, d: [1, 3, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('narrow', 'neg_dim', _small_3d, lambda t, d: [-1, 3, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('nonzero', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('norm', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), - ('norm', '3_norm', _small_3d, lambda t, d: [3], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), - ('norm', '3_norm_dim', _small_3d, lambda t, d: [3, 0], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), - ('norm', '3_norm_neg_dim', _small_3d, lambda t, d: [3, -2], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), + ('norm', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False), + ('norm', '3_norm', _small_3d, lambda t, d: [3], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False), + ('norm', '3_norm_dim', _small_3d, lambda t, d: [3, 0], 1e-1, 1e-1, 1e-5, + torch.testing.get_all_fp_dtypes(), _cpu_types, False), + ('norm', '3_norm_neg_dim', _small_3d, lambda t, d: [3, -2], 1e-1, 1e-1, 1e-5, + torch.testing.get_all_fp_dtypes(), _cpu_types, False), ('new_ones', '', _small_3d, lambda t, d: [1, 2, 3, 4, 5], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('permute', '', _new_t((1, 2, 3, 4)), lambda t, d: [2, 1, 3, 0], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('put_', '', _new_t((2, 5, 3)), @@ -6707,12 +6802,16 @@ def inner(self, device, dtype): torch.LongTensor([[1], [2]]).to(dtype=_convert_t(t, d), device=d), True], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('prod', '', lambda t, d: _small_2d(t, d, oneish=True), - lambda t, d: [], 1e-2, 1e-1, 1e-5, _types2, _cpu_types, False), - ('prod', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-1, 1e-5, _types2, _cpu_types, False), - ('prod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-1, 1e-5, _types2, _cpu_types, False), - ('sum', '', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _types2, _cpu_types, False), - ('sum', 'dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _types2, _cpu_types, False), + ('prod', '', lambda t, d: _small_2d(t, d, oneish=True), lambda t, d: [], 1e-2, 1e-1, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('prod', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-1, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('prod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-1, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('sum', '', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('sum', 'dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), ('sum', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-5, 1e-5, _types, _cpu_types, False), ('sum', 'complex', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False), ('sum', 'complex_dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False), @@ -6796,7 +6895,7 @@ def inner(self, device, dtype): ('geqrf', '', _new_t((20, 20)), lambda t, d: [], 1e-5, 1e-5, 3e-4, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]), ('eig', 'with_eigvec', _new_t((10, 10)), lambda t, d: [True], - 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]), + 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma, onlyOnCPUAndCUDA]), ('abs', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), [torch.bfloat16]), ('sign', '', _small_3d, lambda t, d: []), @@ -6819,9 +6918,6 @@ def inner(self, device, dtype): ('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()), ('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('expm1', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types), - ('expm1', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1), - lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 656845598a49..f08c5341b399 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1770,7 +1770,6 @@ def _medium_2d(dtype, device): # TODO: all these should be replaced with OpInfos torch_op_tests = [ _TorchMathTestMeta('exp'), - _TorchMathTestMeta('expm1'), _TorchMathTestMeta('floor'), _TorchMathTestMeta('ceil'), _TorchMathTestMeta('rad2deg'), diff --git a/test/test_vmap.py b/test/test_vmap.py index 9192c00a94d3..5fa8426fd4ab 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -1365,6 +1365,37 @@ def test_expand_as(self): test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None)) test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5))) + def test_fill_and_zero_inplace(self): + test = functools.partial(self._vmap_test, check_propagates_grad=False) + B0, B1 = 7, 11 + ops = ( + lambda t: t.fill_(0.1), + lambda t: t.fill_(torch.tensor(0.2)), + lambda t: t.zero_(), + ) + + for op in ops: + # Single vmap, various in_dims / out_dims + test(op, [TensorFactory.randn([B0, 3])]) + test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2) + test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [TensorFactory.randn([B0, B1])]) + test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2) + test(vmap(op, in_dims=2), [TensorFactory.randn([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + + # test when value is a batched tensor for fill_ operator + B0, B1 = 3, 5 + test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)]) + + with self.assertRaisesRegex(RuntimeError, + r"output with shape .+ doesn't match the broadcast shape"): + # Runtime Error is thrown when the tensor being written to isn't being vmapped over + vmap(Tensor.fill_, (None, 0))(TensorFactory.randn([B0, B1]), + TensorFactory.randn([B0])) + def _test_complex_views(self, op, dtypes): test = self._vmap_view_test diff --git a/third_party/NNPACK b/third_party/NNPACK index 24b55303f5cf..57616b9a0ef7 160000 --- a/third_party/NNPACK +++ b/third_party/NNPACK @@ -1 +1 @@ -Subproject commit 24b55303f5cf65d75844714513a0d1b1409809bd +Subproject commit 57616b9a0ef7b0f8e56bfe7e9738744b52fe1828 diff --git a/third_party/kineto b/third_party/kineto index bf384310eafa..e9198dd3066e 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit bf384310eafa674a1cc83be71f52ecb320ccdf84 +Subproject commit e9198dd3066ee6e5e20201d6ae6f86f092bb7123 diff --git a/third_party/tensorpipe b/third_party/tensorpipe index 82a114882e21..5381c57ba923 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit 82a114882e21b176916e2f12a7b566af3d63df71 +Subproject commit 5381c57ba923481ffaf7c40f9acc7f164ded887f diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD index 66c7b1c7a1ab..45b99e64ec9a 100644 --- a/third_party/tensorpipe.BUILD +++ b/third_party/tensorpipe.BUILD @@ -93,7 +93,13 @@ TENSORPIPE_HEADERS = glob([ TENSORPIPE_BASE_SRCS = glob([ "tensorpipe/*.cc", "tensorpipe/channel/*.cc", - "tensorpipe/common/*.cc", + "tensorpipe/common/address.cc", + "tensorpipe/common/epoll_loop.cc", + "tensorpipe/common/error.cc", + "tensorpipe/common/fd.cc", + "tensorpipe/common/ibv.cc", + "tensorpipe/common/socket.cc", + "tensorpipe/common/system.cc", "tensorpipe/core/*.cc", "tensorpipe/transport/*.cc", "tensorpipe/util/*/*.cc", @@ -107,7 +113,10 @@ TENSORPIPE_SRCS = TENSORPIPE_BASE_SRCS + glob([ ]) TENSORPIPE_SRCS_CUDA = TENSORPIPE_SRCS + glob([ + "tensorpipe/common/cuda_loop.cc", + "tensorpipe/channel/cuda_basic/*.cc", "tensorpipe/channel/cuda_ipc/*.cc", + "tensorpipe/channel/cuda_xth/*.cc", ]) cc_library( diff --git a/tools/README.md b/tools/README.md index 5f915d510f86..b940d378320b 100644 --- a/tools/README.md +++ b/tools/README.md @@ -24,11 +24,16 @@ Build system pieces: * [setup_helpers](setup_helpers) - Helper code for searching for third-party dependencies on the user system. * [build_pytorch_libs.py](build_pytorch_libs.py) - cross-platform script that - builds all of the constituent libraries of PyTorch, + builds all of the constituent libraries of PyTorch, but not the PyTorch Python extension itself. * [build_libtorch.py](build_libtorch.py) - Script for building libtorch, a standalone C++ library without Python support. This build script is tested in CI. +* [fast_nvcc](fast_nvcc) - Mostly-transparent wrapper over nvcc that + parallelizes compilation when used to build CUDA files for multiple + architectures at once. + * [fast_nvcc.py](fast_nvcc/fast_nvcc.py) - Python script, entrypoint to the + fast nvcc wrapper. Developer tools which you might find useful: @@ -52,8 +57,6 @@ Important if you want to run on AMD GPU: Tools which are only situationally useful: -* [aten_mirror.sh](aten_mirror.sh) - Mirroring script responsible - for keeping https://github.com/zdevito/ATen up-to-date. * [docker](docker) - Dockerfile for running (but not developing) PyTorch, using the official conda binary distribution. Context: https://github.com/pytorch/pytorch/issues/1619 diff --git a/tools/aten_mirror.sh b/tools/aten_mirror.sh deleted file mode 100755 index 6c787bbda568..000000000000 --- a/tools/aten_mirror.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/sh - -# This script is run by a cronjob managed by @zdevito -# which mirrors the ATen-specific directories of PyTorch -# to zdevito/ATen, for ease of use of projects that wish -# to depend solely on ATen. -# -# See also .travis.aten.yml, which is the Travis configuration -# for the ATen project (and ensures ATen is separately -# buildable.) - -if [[ -z "$EXTRACTED_REPO" ]]; then - echo "Need to set envvar EXTRACTED_REPO" - exit 1 -fi -if [[ -z "$FULL_REPO" ]]; then - echo "Need to set envvar FULL_REPO" - exit 1 -fi -rm -rf aten-export-repo -git clone $EXTRACTED_REPO aten-export-repo -cd aten-export-repo -git config user.name "Zach DeVito" -git config user.email "zdevito@fb.com" -git remote add fullrepo $FULL_REPO -git fetch fullrepo -git checkout -b temporary-split-branch fullrepo/master -# Cribbed from https://stackoverflow.com/questions/2982055/detach-many-subdirectories-into-a-new-separate-git-repository -# and https://stackoverflow.com/questions/42355621/git-filter-branch-moving-a-folder-with-index-filter-does-not-work -git filter-branch -f --index-filter 'git rm --cached -qr --ignore-unmatch -- . && git reset -q $GIT_COMMIT -- aten cmake third_party/tbb third_party/catch third_party/cpuinfo && (git ls-files -s | sed "s-.travis.aten.yml-.travis.yml-" | sed "s-.gitmodules.aten-.gitmodules-" | git update-index --index-info)' -git checkout master -git merge temporary-split-branch -git push diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index c393c905c73f..943d9adab4a0 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -52,7 +52,7 @@ def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None: @with_native_function def gen_annotated_args(f: NativeFunction) -> str: out_args: List[Dict[str, Any]] = [] - for arg in f.func.arguments.positional: + for arg in f.func.arguments.flat_positional: if arg.default is not None: continue out_arg: Dict[str, Any] = {} diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 123a47f2aac2..570c99908853 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -193,25 +193,28 @@ def load_signatures( deprecated_yaml_path: str, *, method: bool, + skip_deprecated: bool = False, + pyi: bool = False, ) -> Sequence[PythonSignatureNativeFunctionPair]: native_functions = list(filter(should_generate_py_binding, parse_native_yaml(native_yaml_path))) @with_native_function def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair: return PythonSignatureNativeFunctionPair( - signature=signature(f, method=method), + signature=signature(f, method=method, pyi=pyi), function=f, ) pairs = list(map(gen_signature_pairs, native_functions)) - deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method) - return pairs + deprecated + deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method, pyi=pyi) + return pairs if skip_deprecated else pairs + deprecated def load_deprecated_signatures( pairs: Sequence[PythonSignatureNativeFunctionPair], deprecated_yaml_path: str, *, method: bool, + pyi: bool, ) -> List[PythonSignatureNativeFunctionPair]: # The deprecated.yaml doesn't have complete type information, we need # find and leverage the original ATen signature (to which it delegates @@ -225,6 +228,8 @@ def signature_original(f: NativeFunction) -> str: opname = str(f.func.name.name.base) if f.func.is_out_fn(): opname += '_out' + if f.func.name.name.inplace and pyi: + opname += '_' args = CppSignatureGroup.from_schema(f.func, method=False).signature.arguments() # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml. types = ', '.join(argument_type_str(a.argument.type) @@ -308,6 +313,7 @@ def signature_deprecated(opname: str, params: List[str], call_args: List[str]) - method=python_sig.method, deprecated_args_names=tuple(args), deprecated_args_exprs=tuple(call_args), + returns=python_sig.returns, ), function=pair.function, )) @@ -320,31 +326,10 @@ def signature_deprecated(opname: str, params: List[str], call_args: List[str]) - # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# TODO: remove the copy of this method in 'tools/pyi/gen_pyi.py'. -@with_native_function -def namedtuple_fieldnames(f: NativeFunction) -> List[str]: - returns = f.func.returns - if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)): - return [] - else: - if any(map(lambda r: r.name is None, returns)): - # When building on Windows, `PyStructSequence_UnnamedField` could not be - # resolved by the linker for some reason, which cause error in building: - # - # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol - # PyStructSequence_UnnamedField - # - # Thus, at this point in time, we do not support unnamed - # fields in namedtuple; you must either name all fields, - # or none of them. - raise ValueError("Unnamed field is not supported by codegen") - - return list(map(lambda r: str(r.name), returns)) - @with_native_function def gen_namedtuple_typename_key(f: NativeFunction) -> str: name = cpp.name(f.func) - fieldnames = namedtuple_fieldnames(f) + fieldnames = namedtuple_fieldnames(f.func.returns) return '_'.join([name] + fieldnames) def emit_namedtuple_typedefs( @@ -360,7 +345,7 @@ def emit_namedtuple_typedefs( typedefs: List[str] = [] # typedef declarations and init code for overload in overloads: - fieldnames = namedtuple_fieldnames(overload.function) + fieldnames = namedtuple_fieldnames(overload.function.func.returns) if not fieldnames: continue @@ -651,7 +636,7 @@ def method_def( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # def group_overloads( - overloads: Sequence[PythonSignatureNativeFunctionPair] + overloads: Sequence[PythonSignatureNativeFunctionPair], ) -> Sequence[PythonSignatureGroup]: bases: Dict[str, PythonSignatureNativeFunctionPair] = {} outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {} diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index e55402f9e68d..6bc83b9716e6 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -126,11 +126,17 @@ def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequen args = [cpp_args.argument for cpp_args in sig_group.signature.arguments()] if f.func.is_out_fn(): - # *_out functions take the result as a first argument, but they are the - # last argument in the JIT schema. + # *_out functions take the result as a separate argument, but we don't want to + # trace that argument directly. Instead, we trace its TensorOptions. + # So first, we need to remove the out argument from the list of arguments to trace. # TODO: byte-for-byte compatible with old codegen behavior - it's incorrect to assume # there is only one output argument. - args = args[1:] + if f.use_c10_dispatcher.dispatcher_uses_new_style(): + # for c10-full ops, the out argument is in the end + args = args[:-1] + else: + # for legacy ops, the out argument is in the beginning. + args = args[1:] trace_inputs = itertools.chain.from_iterable(dispatch_trace_input(arg) for arg in args) @@ -144,8 +150,7 @@ def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequen # Factories are a bit special because their out-of-place overloads # take an extra TensorOptions argument, which is missing in the _out function has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) - has_tensor_input_arg = any(a.type.is_tensor_like() - for a in itertools.chain(f.func.arguments.positional, f.func.arguments.kwarg_only)) + has_tensor_input_arg = any(a.type.is_tensor_like() for a in f.func.arguments.flat_non_out) is_factory_method = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg) # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method` diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 6c2bfbe4dcaa..a17e222f8cf1 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -23,7 +23,7 @@ # differentiable subcomponents. # -from .utils import CodeTemplate, nested_dict, write +from .utils import CodeTemplate, nested_dict, write, make_out_api_name_faithful from .gen_autograd import VIEW_FUNCTIONS, VIEW_FUNCTIONS_WITH_METADATA_CHANGE, \ MULTI_OUTPUT_SAFE_FUNCTIONS, RETURNS_VIEWS_OF_INPUT from .gen_autograd_functions import uses_single_grad @@ -78,7 +78,7 @@ 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', - 'matrix_exp', 'linalg_eigh', 'cholesky_solve', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', '_fft_c2c', '_fft_r2c', } @@ -615,8 +615,13 @@ def save_variables( def emit_dispatch_call(api_name, input_base, unpacked_args): """ Dispatch call via function in a namespace or method on Tensor.""" if 'namespace' in declaration['method_of']: + if declaration['use_c10_dispatcher'] in ['hacky_wrapper_for_legacy_signatures', 'full']: + dispatcher_api_name = make_out_api_name_faithful(api_name) + else: + assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' + dispatcher_api_name = api_name call = CALL_DISPATCH_VIA_NAMESPACE.substitute( - api_name=api_name, + api_name=dispatcher_api_name, unpacked_args=unpacked_args) else: call = CALL_DISPATCH_VIA_METHOD.substitute( @@ -698,8 +703,9 @@ def wrap_output(return_values, var): creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE" else: creation_meta = "CreationMeta::MULTI_OUTPUT_NODE" - rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, " - "/* creation_meta */ {})").format(view_info, var, creation_meta) + call += ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, " + "/* creation_meta */ {});\n").format(view_info, var, creation_meta) + rhs_value = 'std::move({})'.format(var) else: call += emit_view_lambda() creation_meta = "GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE" diff --git a/tools/autograd/utils.py b/tools/autograd/utils.py index 5c0fcccc4c78..86758b5b3ff3 100644 --- a/tools/autograd/utils.py +++ b/tools/autograd/utils.py @@ -1,8 +1,8 @@ import re import os import yaml -from collections import defaultdict from .nested_dict import nested_dict +from typing import Dict, List __all__ = [ @@ -47,12 +47,18 @@ def split_name_params(prototype): def uninplace_api_name(api_name): if api_name.endswith('_') and not api_name.endswith('__'): api_name = api_name[:-1] + return unout_api_name(api_name) + +def make_out_api_name_faithful(api_name): + # Variable kernel needs to call the _outf overload instead of the _out overload + # because the _outf overload matches the argument order as it's passed into + # the variable kernel if api_name.endswith('_out'): - api_name = api_name[:-4] + api_name = api_name + 'f' return api_name -def write(dirname, name, template, env): +def write(dirname: str, name: str, template: CodeTemplate, env: Dict[str, List[str]]) -> None: env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template.filename) path = os.path.join(dirname, name) # See Note [Unchanging results for ninja] @@ -69,12 +75,6 @@ def write(dirname, name, template, env): else: print("Skipped writing {}".format(path)) -def is_tensor_method(declaration): - return 'Tensor' in declaration['method_of'] - -def is_torch_function(declaration): - return 'namespace' in declaration['method_of'] - def is_out_variant(decl): return decl['name'].endswith('_out') @@ -92,12 +92,6 @@ def load_op_list_and_strip_overload(op_list, op_list_path): # strip out the overload part return {opname.split('.', 1)[0] for opname in op_list} -def group_declarations_by_op_name(declarations): - groups = defaultdict(list) - for d in declarations: - groups[op_name(d)].append(d) - return groups - def is_output(arg): return arg.get('output', False) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 7e5a5e4e7f8a..eca10839ae88 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -96,6 +96,7 @@ core_sources_common = [ "torch/csrc/jit/runtime/jit_exception.cpp", "torch/csrc/jit/runtime/operator.cpp", "torch/csrc/jit/runtime/print_handler.cpp", + "torch/csrc/jit/runtime/slice_indices_adjust.cpp", "torch/csrc/jit/runtime/register_ops_utils.cpp", "torch/csrc/jit/runtime/vararg_functions.cpp", "torch/csrc/jit/serialization/unpickler.cpp", @@ -175,7 +176,6 @@ core_sources_full_mobile = [ "torch/csrc/jit/passes/erase_number_types.cpp", "torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp", "torch/csrc/jit/passes/freeze_module.cpp", - "torch/csrc/jit/passes/reconstruct_scopes.cpp", "torch/csrc/jit/passes/fuse_linear.cpp", "torch/csrc/jit/passes/fuse_relu.cpp", "torch/csrc/jit/passes/graph_fuser.cpp", @@ -266,8 +266,10 @@ core_sources_full_mobile = [ ] core_sources_full = core_sources_full_mobile + [ + "torch/csrc/jit/runtime/static/fusion.cpp", "torch/csrc/jit/runtime/static/impl.cpp", "torch/csrc/jit/runtime/static/ops.cpp", + "torch/csrc/jit/runtime/static/passes.cpp", ] libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources) @@ -470,6 +472,7 @@ libtorch_python_cuda_core_sources = [ "torch/csrc/cuda/python_comm.cpp", "torch/csrc/cuda/Storage.cpp", "torch/csrc/cuda/Stream.cpp", + "torch/csrc/cuda/Graph.cpp", "torch/csrc/cuda/serialization.cpp", "torch/csrc/cuda/shared/cudart.cpp", "torch/csrc/cuda/shared/nvtx.cpp", diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py index b20497b5a82c..f2f6edb88983 100644 --- a/tools/codegen/api/cpp.py +++ b/tools/codegen/api/cpp.py @@ -23,10 +23,14 @@ # BTW: policy on name collisions: we try not to have types with # collisions, but functions are fair game to collide -def name(func: FunctionSchema) -> str: +def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str: name = str(func.name.name) if func.is_out_fn(): - name += '_out' + if faithful_name_for_out_overloads: + name += '_outf' + else: + name += '_out' + return name # Translation of "value types" in JIT schema to C++ API type. Value @@ -252,14 +256,21 @@ def argument_not_this( def argument( a: Union[Argument, TensorOptionsArguments, SelfArgument], + *, + method: bool, ) -> Union[CppSingleArgumentPack, CppThisArgumentPack]: if isinstance(a, SelfArgument): - return CppThisArgumentPack(argument=a, type=argument_type(a.argument)) + if method: + return CppThisArgumentPack(argument=a, type=argument_type(a.argument)) + else: + return CppSingleArgumentPack(argument_not_this(a.argument)) else: return CppSingleArgumentPack(argument_not_this(a)) def argument_faithful( a: Union[Argument, TensorOptionsArguments, SelfArgument], + *, + method: bool, ) -> CppArgumentPack: if isinstance(a, TensorOptionsArguments): return CppTensorOptionsArgumentPack( @@ -270,22 +281,4 @@ def argument_faithful( pin_memory=argument_not_this(a.pin_memory), ) else: - return argument(a) - -def group_arguments( - func: FunctionSchema, *, method: bool -) -> Sequence[Union[Argument, TensorOptionsArguments, SelfArgument]]: - args: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] - args.extend(func.arguments.out) - args.extend(func.arguments.pre_self_positional) - if func.arguments.self_arg is not None: - if method: - args.append(func.arguments.self_arg) - else: - args.append(func.arguments.self_arg.argument) - args.extend(func.arguments.post_self_positional) - args.extend(func.arguments.pre_tensor_options_kwarg_only) - if func.arguments.tensor_options is not None: - args.append(func.arguments.tensor_options) - args.extend(func.arguments.post_tensor_options_kwarg_only) - return args + return argument(a, method=method) diff --git a/tools/codegen/api/dispatcher.py b/tools/codegen/api/dispatcher.py index 8f3925de0041..165b68e3a830 100644 --- a/tools/codegen/api/dispatcher.py +++ b/tools/codegen/api/dispatcher.py @@ -6,7 +6,7 @@ import tools.codegen.local as local import itertools -from typing import Sequence, Optional, Tuple +from typing import Sequence, Optional, Tuple, List, Union # This file describes the translation of JIT schema to the dispatcher # API, the *unboxed* calling convention by which invocations through @@ -68,7 +68,11 @@ def name(func: FunctionSchema) -> str: def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]: if local.use_c10_dispatcher().dispatcher_uses_new_style(): - return tuple(map(argument, itertools.chain(func.arguments.out, func.arguments.positional, func.arguments.kwarg_only))) + return tuple(map(argument, itertools.chain( + func.arguments.flat_positional, + func.arguments.flat_kwarg_only, + func.arguments.out + ))) else: return tuple( DispatcherArgument(type=la.type, name=la.name, argument=la.argument) @@ -137,7 +141,29 @@ def cppargument_exprs( else: assert_never(a) -def cpparguments_exprs(args: Sequence[CppArgumentPack]) -> Sequence[DispatcherExpr]: +def cpparguments_exprs(func: FunctionSchema, * , method: bool, api_is_faithful: bool) -> Sequence[DispatcherExpr]: + dispatcher_is_faithful = local.use_c10_dispatcher().dispatcher_uses_new_style() + + arguments: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] + if dispatcher_is_faithful: + arguments.extend(func.arguments.non_out) + arguments.extend(func.arguments.out) + else: + arguments.extend(func.arguments.out) + arguments.extend(func.arguments.non_out) + + if api_is_faithful: + argument_packs = tuple( + cpp.argument_faithful(a, method=method) for a in arguments + ) + else: + argument_packs = tuple( + cpp.argument(a, method=method) for a in arguments + ) + + return _cpparguments_exprs(argument_packs) + +def _cpparguments_exprs(args: Sequence[CppArgumentPack]) -> Sequence[DispatcherExpr]: tensor_options = next( (a.this for a in args if isinstance(a, CppSingleArgumentPack) and isinstance(a.this.argument, TensorOptionsArguments)), @@ -148,13 +174,13 @@ def cpparguments_exprs(args: Sequence[CppArgumentPack]) -> Sequence[DispatcherEx # I don't think this is entirely sound, but it should be reasonably # close def nativearguments_exprs(args: Sequence[NativeArgument]) -> Sequence[DispatcherExpr]: - return cpparguments_exprs([ + return _cpparguments_exprs([ CppSingleArgumentPack(CppArgument(type=a.type, name=a.name, default=None, argument=a.argument)) for a in args ]) def exprs(args: Sequence[DispatcherArgument]) -> Sequence[DispatcherExpr]: - return cpparguments_exprs([ + return _cpparguments_exprs([ CppSingleArgumentPack(CppArgument(type=a.type, name=a.name, default=None, argument=a.argument)) for a in args ]) diff --git a/tools/codegen/api/meta.py b/tools/codegen/api/meta.py index 4bfc8e837ec1..6beee3eaefbb 100644 --- a/tools/codegen/api/meta.py +++ b/tools/codegen/api/meta.py @@ -1,11 +1,9 @@ from tools.codegen.model import * from tools.codegen.api.types import MetaArgument -import tools.codegen.api.cpp as cpp import tools.codegen.api.dispatcher as dispatcher from typing import Sequence -import itertools # Follows dispatcher calling convention, but: # - Mutable arguments not allowed. Meta functions are always @@ -13,40 +11,14 @@ # - No tensor returns; instead we return a TensorMeta describing # the tensor in question -def name(f: FunctionSchema) -> str: - assert f.name.overload_name == "" - return str(f.name.name) +def name(g: StructuredNativeFunctions) -> str: + # use the overload name from the functional version + return str(g.functional.func.name).replace('.', '_') def argument_type(a: Argument) -> str: assert not a.is_write return dispatcher.argumenttype_type(a.type, mutable=False) -def returntype_type(t: Type) -> str: - r = cpp.valuetype_type(t) - if r is not None: - return r - - if isinstance(t, BaseType): - if t.name == BaseTy.Tensor: - return 'TensorMeta' - elif isinstance(t, ListType): - raise NotImplementedError("list returns not supported yet") - - raise AssertionError(f"unrecognized return type {t}") - -def return_type(r: Return) -> str: - assert not r.is_write - return returntype_type(r.type) - -def returns_type(rs: Sequence[Return]) -> str: - if len(rs) == 0: - return 'void' - elif len(rs) == 1: - return return_type(rs[0]) - else: - args = ','.join(map(return_type, rs)) - return f'std::tuple<{args}>' - def argument(a: Argument) -> MetaArgument: return MetaArgument( type=argument_type(a), @@ -56,4 +28,4 @@ def argument(a: Argument) -> MetaArgument: def arguments(func: FunctionSchema) -> Sequence[MetaArgument]: assert not func.arguments.out - return list(map(argument, itertools.chain(func.arguments.positional, func.arguments.kwarg_only))) + return list(map(argument, func.arguments.flat_non_out)) diff --git a/tools/codegen/api/native.py b/tools/codegen/api/native.py index b9e5257aef85..7ae0325ec324 100644 --- a/tools/codegen/api/native.py +++ b/tools/codegen/api/native.py @@ -4,7 +4,7 @@ import tools.codegen.api.cpp as cpp from tools.codegen import local -from typing import Union, Sequence, Tuple +from typing import Union, Sequence, Tuple, List # This file describes the translation of JIT schema to the native functions API. # This looks a lot like the C++ API (which makes historical sense, because the @@ -105,4 +105,11 @@ def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> Sequen assert_never(a) def arguments(func: FunctionSchema) -> Tuple[NativeArgument, ...]: - return tuple(i for arg in cpp.group_arguments(func, method=False) for i in argument(arg)) + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] + if local.use_c10_dispatcher() is UseC10Dispatcher.full: + args.extend(func.arguments.non_out) + args.extend(func.arguments.out) + else: + args.extend(func.arguments.out) + args.extend(func.arguments.non_out) + return tuple(i for arg in args for i in argument(arg)) diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index 4b407d45553a..c78fe23150e8 100644 --- a/tools/codegen/api/python.py +++ b/tools/codegen/api/python.py @@ -1,4 +1,3 @@ -import itertools from dataclasses import dataclass from typing import Optional, Union, Sequence, Set, List, Tuple, Dict @@ -14,6 +13,8 @@ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # +# [Notes] python binding codegen +# # The Python binding codegen produces code that takes the input list of # PyObjects, finds the matching ATen C++ function using PythonArgParser, # converts the PyObjects into C++ types and calls the ATen C++ function: @@ -172,6 +173,43 @@ # return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0))); # } # +# +# [Notes] python interface codegen +# The python dataclasses below are used used to generate both python binding code +# and pyi type hint signatures. +# In theory these two should look very similar, but there are number of differences +# in how pyi signatures vs. python_arg_parser signatures are generated. +# These differences have been encapsulated in signature_str() vs. signature_str_pyi() +# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments. +# For examples, only pyi signatures include return types. + +@dataclass(frozen=True) +class PythonReturns: + returns: Tuple[Return, ...] + + def named_tuple_pyi(self) -> Optional[Tuple[str, str]]: + python_returns = [argument_type_str_pyi(r.type) for r in self.returns] + field_names = namedtuple_fieldnames(self.returns) + if field_names: + namedtuple_name = '_'.join(['namedtuple'] + field_names) + tuple_args = [f'("{name}", {typ})' for name, typ in zip(field_names, python_returns)] + namedtuple_def = f'NamedTuple("{namedtuple_name}", [{", ".join(tuple_args)}])' + return namedtuple_name, namedtuple_def + return None + + def returns_str_pyi(self) -> str: + named_tuple = self.named_tuple_pyi() + if named_tuple is not None: + namedtuple_name, _ = named_tuple + return namedtuple_name + + python_returns = [argument_type_str_pyi(r.type) for r in self.returns] + if len(python_returns) > 1: + return 'Tuple[' + ', '.join(python_returns) + ']' + if len(python_returns) == 1: + return python_returns[0] + return 'None' + @dataclass(frozen=True) class PythonArgument: @@ -192,10 +230,10 @@ class PythonArgument: def argument_str(self, *, method: bool = False) -> str: type_str = argument_type_str(self.type) + name = self.name # s/self/input/ outside method bindings # [old codegen] TODO: remove this? doesn't rename in codegen, it's just # for the parse string - name = self.name if name == 'self' and type_str == 'Tensor' and not method: name = 'input' @@ -210,6 +248,43 @@ def argument_str(self, *, method: bool = False) -> str: else: return f'{type_str} {name}' + def argument_str_pyi(self, *, method: bool = False, deprecated: bool = False) -> str: + type_str = argument_type_str_pyi(self.type) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == 'self' and type_str == 'Tensor' and not method and not deprecated: + name = 'input' + + if name == 'from': # from is a Python keyword... + name += '_' + + # pyi merges the _out and functional variants into the same signature, with an optional out arg + if name == 'out' and type_str == 'Tensor' and not deprecated: + type_str = 'Optional[' + type_str + ']' + + # pyi deprecated signatures don't get defaults for their out arg + treat_as_no_default = deprecated and isinstance(self, PythonOutArgument) and self.default == 'None' + + # add default + if self.default is not None and not treat_as_no_default: + if isinstance(self.type, ListType) and self.type.elem == BaseType(BaseTy.int) and \ + self.default.startswith('{') and self.default.endswith('}'): + default = '(' + self.default[1:-1] + ')' + else: + default = { + 'nullptr': 'None', + 'c10::nullopt': 'None', + '{}': 'None', + 'MemoryFormat::Contiguous': 'contiguous_format', + 'QScheme::PER_TENSOR_AFFINE': 'per_tensor_affine', + }.get(self.default, self.default) + return f'{name}: {type_str}={default}' + else: + return f'{name}: {type_str}' + @dataclass(frozen=True) class PythonOutArgument(PythonArgument): # In Python signature multiple output fields are packed into one 'out' argument. @@ -238,6 +313,7 @@ def from_outputs(outputs: Tuple[PythonArgument, ...]) -> Optional['PythonOutArgu raise RuntimeError(f'Unsupported output type: {outputs}') return PythonOutArgument( name='out', + # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None? type=ListType(BaseType(BaseTy.Tensor), size), default='None', default_init=None, @@ -260,6 +336,9 @@ class PythonSignature: output_args: Optional[PythonOutArgument] + # Return types, which are only used by pyi + returns: PythonReturns + # These are scattered kwargs arguments belonging to TensorOptions. # When binding to C++, they are packed into a TensorOptions object 'options'. # It's possible that the C++ signature doesn't take TensorOptions object (e.g. @@ -301,18 +380,56 @@ def output_idx(self) -> int: # for error parsing. # # For a translation to mypy-valid type signatures, see - # tools/gen_pyi.py. If you change any logic here, please - # check that file too. + # signature_str_pyi(). def signature_str(self, *, skip_outputs: bool = False) -> str: - schema_formals: List[str] = \ - list(map(lambda a: a.argument_str(method=self.method), - self.arguments(skip_outputs=skip_outputs))) + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method), args)) positional_argc = len(self.input_args) if len(schema_formals) > positional_argc: schema_formals.insert(positional_argc, '*') return f'{self.name}({", ".join(schema_formals)})' + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, '*') + + # only pyi signatures include returns + returns_str = self.returns.returns_str_pyi() + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: + # only pyi uses vararg signatures + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) + # vararg only applies to pyi signatures. vararg variants are not generated for all signatures + num_args = self.arguments_count() + num_positionalargs = len(self.input_args) + + have_vararg_version = False + if num_args > 0: + vararg_type = args[0].type + if isinstance(vararg_type, ListType) and str(vararg_type.elem) == 'int' and num_positionalargs == 1: + have_vararg_version = True + + if not have_vararg_version: + return None + # Below are the major changes in vararg vs. regular pyi signatures + # vararg signatures also omit the asterix + schema_formals[0] = '*' + args[0].name + ': _int' + + returns_str = self.returns.returns_str_pyi() + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + # The deprecated python signature involves some special logic, so create a # dedicated data model to store these extra properties. @dataclass(frozen=True) @@ -340,6 +457,20 @@ def deprecated(self) -> bool: def signature_str(self, *, skip_outputs: bool = False) -> str: return PythonSignature.signature_str(self, skip_outputs=skip_outputs) + '|deprecated' + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args)) + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, '*') + + returns_str = self.returns.returns_str_pyi() + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: + # the codegen doesn't include vararg variants for deprecated signatures + return None + # This struct is used to hold the PythonSignature and its corresponding # NativeFunction BEFORE grouping base and out-variant functions. # Why not store NativeFunction in PythonSignature or construct PythonSignature @@ -438,8 +569,7 @@ def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature: return CppSignatureGroup.from_schema(f.func, method=method).signature def has_tensor_options(f: NativeFunction) -> bool: - return any(filter(lambda a: isinstance(a, TensorOptionsArguments), - cpp.group_arguments(f.func, method=False))) + return f.func.arguments.tensor_options is not None # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # @@ -490,6 +620,8 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str: return f'IntArrayRef[{size}]' if size is not None else 'IntArrayRef' elif str(t.elem) == 'Tensor': return f'TensorList[{size}]' if size is not None else 'TensorList' + elif str(t.elem) == 'Scalar': + return f'ScalarList[{size}]' if size is not None else 'ScalarList' elif str(t.elem) == 'Tensor?': if simple_type: return 'TensorList' @@ -520,15 +652,22 @@ def argument(a: Argument) -> PythonArgument: default_init=None, ) -def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature: - # Use cpp api to gather TensorOptions fields from kwargs. - # Skip ThisArgument if this is method signature. - # Skip TensorOptionsArguments in C++ signature. Python side TensorOptions +# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen +def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> PythonSignature: + args: List[Argument] = [] + args.extend(f.func.arguments.pre_self_positional) + # Skip SelfArgument if this is method. + if not method and f.func.arguments.self_arg is not None: + args.append(f.func.arguments.self_arg.argument) + args.extend(f.func.arguments.post_self_positional) + args.extend(f.func.arguments.pre_tensor_options_kwarg_only) + # Skip TensorOptionsArguments. Python side TensorOptions # arguments are created based on different rules - see below. - args = tuple(a for a in cpp.group_arguments(f.func, method=method) if isinstance(a, Argument)) + args.extend(f.func.arguments.post_tensor_options_kwarg_only) + args.extend(f.func.arguments.out) - input_arg_set = set(a.name for a in f.func.arguments.positional) - kwarg_only_set = set(a.name for a in f.func.arguments.kwarg_only) + input_arg_set = set(a.name for a in f.func.arguments.flat_positional) + kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only) out_arg_set = set(a.name for a in f.func.arguments.out) input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) @@ -543,8 +682,7 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature: # to the original versions in the yaml, this recreation is a potential # source of drift between eager and JIT. Pull this logic out to a shared place. - has_tensor_input_arg = any(a.type.is_tensor_like() - for a in itertools.chain(f.func.arguments.positional, f.func.arguments.kwarg_only)) + has_tensor_input_arg = any(a.type.is_tensor_like() for a in f.func.arguments.flat_non_out) if any(a.name == 'requires_grad' for a in f.func.schema_order_arguments()): raise ValueError('argument named requires_grad is reserved, should not explicitly add it in the schema') @@ -561,13 +699,13 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature: tensor_options_args.append(PythonArgument( name='dtype', type=BaseType(BaseTy.ScalarType), - default=_dtype_default_type_hack(name), + default='None' if pyi else _dtype_default_type_hack(name), default_init='self.scalar_type()' if is_like_or_new_function else None, )) tensor_options_args.append(PythonArgument( name='layout', type=OptionalType(BaseType(BaseTy.Layout)), - default='torch.strided', + default='strided' if pyi else 'torch.strided', default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None, )) tensor_options_args.append(PythonArgument( @@ -589,21 +727,107 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature: default_init=None, )) + returns = PythonReturns(returns=f.func.returns) + return PythonSignature( name=str(f.func.name.name), input_args=input_args, input_kwargs=input_kwargs, output_args=PythonOutArgument.from_outputs(outputs), tensor_options_args=tuple(tensor_options_args), + returns=returns, method=method, ) # TODO blowtorch +# note: removing this will be BC-breaking. A quick test shows that +# randperm will otherwise default its dtype to torch.float64 def _dtype_default_type_hack(name: str) -> str: if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices': return 'torch.int64' else: return 'None' +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Interface +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]: + if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)): + return [] + else: + if any(map(lambda r: r.name is None, returns)): + # When building on Windows, `PyStructSequence_UnnamedField` could not be + # resolved by the linker for some reason, which cause error in building: + # + # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol + # PyStructSequence_UnnamedField + # + # Thus, at this point in time, we do not support unnamed + # fields in namedtuple; you must either name all fields, + # or none of them. + raise ValueError("Unnamed field is not supported by codegen") + + return list(map(lambda r: str(r.name), returns)) + +def argument_type_str_pyi(t: Type) -> str: + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + + if isinstance(t, BaseType): + if t.name == BaseTy.int: + ret = '_int' + elif t.name == BaseTy.float: + ret = '_float' + elif t.name == BaseTy.str: + ret = 'str' + elif t.name == BaseTy.Scalar: + ret = 'Number' + elif t.name == BaseTy.ScalarType: + ret = '_dtype' + elif t.name == BaseTy.bool: + ret = '_bool' + elif t.name == BaseTy.QScheme: + ret = '_qscheme' + elif t.name == BaseTy.Layout: + ret = '_layout' + elif t.name == BaseTy.Device: + ret = 'Union[_device, str, None]' + elif t.name == BaseTy.MemoryFormat: + ret = 'memory_format' + elif t.name == BaseTy.Dimname: + ret = 'Union[str, ellipsis, None]' + elif t.name in [BaseTy.Tensor, BaseTy.Generator, + BaseTy.Storage, BaseTy.Stream, BaseTy.str]: + # These python schema type names line up with their function schema names + ret = t.name.name + + elif isinstance(t, ListType): + if str(t.elem) == 'int': + ret = 'Union[_int, _size]' if t.size is not None else '_size' + elif t.is_tensor_like(): + # TODO: this doesn't seem right... + # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]] + # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]] + if isinstance(t.elem, OptionalType): + add_optional = True + ret = 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]' if t.size is not None else \ + 'Union[Tuple[Tensor, ...], List[Tensor]]' + elif str(t.elem) == 'float': + ret = 'Sequence[float]' + else: + elem = argument_type_str_pyi(t.elem) + ret = f'Sequence[{elem}]' + + if add_optional: + ret = 'Optional[' + ret + ']' + return ret + + raise RuntimeError(f'unrecognized type {repr(t)}') + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # @@ -841,7 +1065,8 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str: return 'intlist' elif str(t) == 'float[]': return 'doublelist' - + elif str(t) == 'Scalar[]': + return 'scalarlist' raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser') # Return RHS expression for python argument using PythonArgParser output. diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index 32caf26f223f..55a6e4abc52c 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -1,6 +1,6 @@ from tools.codegen.model import * from dataclasses import dataclass -from typing import Optional, Union, Sequence, Tuple, TypeVar +from typing import Optional, Union, Sequence, Tuple, TypeVar, List _T = TypeVar('_T') @@ -68,7 +68,7 @@ class CppSingleArgumentPack(CppArgumentPackIface): this: CppArgument def no_default(self) -> 'CppSingleArgumentPack': - return CppSingleArgumentPack(self.this.no_default()) + return CppSingleArgumentPack(this=self.this.no_default()) @property def type(self) -> str: @@ -150,68 +150,67 @@ class CppSignature: # The schema this signature is derived from func: FunctionSchema - # Enough information about the C++ types to generate a full - # C++ type signature for this signature. I'm not too sure - # if these are the right representations, so for now this - # is intended to be more abstract. - _argument_packs: Tuple[CppArgumentPack, ...] - _returns_type: str + # Is this a C++ signature for a method, i.e. Tensor::my_op(...)? + method: bool + + # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API + # (i.e. with a potential TensorOptions argument and out arguments in the front) + faithful: bool + + fallback_binding: bool = False # Return the unpacked argument structure of this signature, # discarding information about which arguments are semantically # related to each other. def arguments(self) -> Sequence[CppArgument]: - return [sub_a for a in self._argument_packs for sub_a in a.explicit_arguments()] + return [sub_a for a in self.argument_packs() for sub_a in a.explicit_arguments()] # Return the packed argument structure of this signature. This preserves # high-level structure of the arguments so you may find it easier to do # translations working with this representation. def argument_packs(self) -> Sequence[CppArgumentPack]: - return self._argument_packs - - # Render the C++ declaration for this signature - def decl(self) -> str: - cpp_args_str = ', '.join(map(str, self.arguments())) - return f"{self._returns_type} {cpp.name(self.func)}({cpp_args_str})" - - # Render the C++ definition for this signature, not including - # the body (with curly braces) - def defn(self, name: Optional[str] = None, *, prefix: str = "") -> str: - cpp_args_str = ', '.join(a.str_no_default() for a in self.arguments()) - if name is None: - name = prefix + cpp.name(self.func) - return f"{self._returns_type} {name}({cpp_args_str})" + arguments: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] + if self.faithful: + arguments.extend(self.func.arguments.non_out) + arguments.extend(self.func.arguments.out) + else: + arguments.extend(self.func.arguments.out) + arguments.extend(self.func.arguments.non_out) - # NB: This constructor knows how to disambiguate defaults when - # faithful is True. Ideally this would live as an external process - # see https://github.com/pytorch/pytorch/pull/45666 - @staticmethod - def _from_grouped_arguments( - func: FunctionSchema, - arguments: Sequence[Union[Argument, TensorOptionsArguments, SelfArgument]], - *, - faithful: bool - ) -> 'CppSignature': - if faithful: - # Faithful signatures will ungroup arguments into argument - # packs. - # + if self.faithful: # After this, manually do overload disambiguation, by # dropping defaults from the faithful signature. In # principle, we should be able to do this at some later # point in time with other overload disambiguation argument_packs = tuple( - cpp.argument_faithful(a).no_default() for a in arguments + cpp.argument_faithful(a, method=self.method).no_default() for a in arguments ) else: argument_packs = tuple( - cpp.argument(a) for a in arguments + cpp.argument(a, method=self.method) for a in arguments ) - return CppSignature( - func=func, - _argument_packs=argument_packs, - _returns_type=cpp.returns_type(func.returns), - ) + return argument_packs + + def name(self) -> str: + n = cpp.name(self.func, faithful_name_for_out_overloads=self.faithful) + if self.fallback_binding: + n = f"__dispatch_{n}" + return n + + # Render the C++ declaration for this signature + def decl(self) -> str: + returns_type = cpp.returns_type(self.func.returns) + cpp_args_str = ', '.join(map(str, self.arguments())) + return f"{returns_type} {self.name()}({cpp_args_str})" + + # Render the C++ definition for this signature, not including + # the body (with curly braces) + def defn(self, *, prefix: str = "") -> str: + returns_type = cpp.returns_type(self.func.returns) + cpp_args_str = ', '.join(a.str_no_default() for a in self.arguments()) + name = prefix + self.name() + return f"{returns_type} {name}({cpp_args_str})" + # Represents group of all CppSignatures associated with a # FunctionSchema. Right now, that's the regular, user-visible @@ -224,14 +223,13 @@ class CppSignatureGroup: faithful_signature: Optional[CppSignature] @staticmethod - def from_schema(func: FunctionSchema, *, method: bool) -> 'CppSignatureGroup': - grouped_arguments = cpp.group_arguments(func, method=method) + def from_schema(func: FunctionSchema, *, method: bool, fallback_binding: bool = False) -> 'CppSignatureGroup': faithful_signature: Optional[CppSignature] - if any(isinstance(a, TensorOptionsArguments) for a in grouped_arguments): - faithful_signature = CppSignature._from_grouped_arguments(func, grouped_arguments, faithful=True) + if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: + faithful_signature = CppSignature(func=func, faithful=True, method=method, fallback_binding=fallback_binding) else: faithful_signature = None - signature = CppSignature._from_grouped_arguments(func, grouped_arguments, faithful=False) + signature = CppSignature(func=func, faithful=False, method=method, fallback_binding=fallback_binding) return CppSignatureGroup( func=func, signature=signature, @@ -357,6 +355,10 @@ def defn(self, name: Optional[str] = None) -> str: name = self.name() return f"{self._returns_type} {name}({args_str})" + def ptr_type(self) -> str: + args_str = ', '.join(map(str, self.arguments())) + return f'{self._returns_type} (*)({args_str})' + def arguments(self) -> Tuple[NativeArgument, ...]: return self._arguments @@ -385,7 +387,7 @@ class MetaArgument: type: str name: str # structured kernels (for which MetaArgument matters) always will - # be use_c10_dispatcher full. That means JIT arguments and + # be use_c10_dispatcher full. That means JIT arguments and # meta arguments are always in 1:1 correspondence. If this is ever not true # we will have to do something more fancy here. argument: Argument diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 4db060acd401..9ad4099d9196 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -188,6 +188,11 @@ def is_generic_dispatch_key(dk: str) -> bool: def is_cuda_dispatch_key(dk: str) -> bool: return 'CUDA' in dk +# Structured kernel generation is only supported for certain key types; +# otherwise use old-style +def is_structured_dispatch_key(dk: str) -> bool: + return dk in {'CUDA', 'CPU'} + # Generates RegisterSchema.cpp. Depending on the selector, either # all schemas are registered, or only some are (in the case of # selective build) @@ -230,6 +235,9 @@ class RegisterDispatchKey: # registration code for. selector: SelectiveBuilder + # Whether or not we are actually code-genning for ROCm + rocm: bool + def __post_init__(self) -> None: assert self.target is not Target.DECLARATION @@ -243,6 +251,126 @@ def __call__(self, f: Union[StructuredNativeFunctions, NativeFunction]) -> List[ else: assert_never(f) + def gen_structured_class_set_output(self, k: SchemaKind, parent_class: str, generate_super: bool) -> str: + if generate_super: + set_output_super = f"{parent_class}::set_output(output_idx, sizes, strides, options, names);" + else: + set_output_super = "" + return f""" +void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, + TensorOptions options, DimnameList names) override {{ + {self.gen_structured_class_set_output_body(k)} + if (!names.empty()) namedinference::propagate_names(outputs_[output_idx], names); + // super must happen after, so that downstream can use maybe_get_output + // to retrieve the output + {set_output_super} +}} +""" + + def gen_structured_class_set_output_body(self, k: SchemaKind) -> str: + if self.dispatch_key == 'CUDA': + maybe_set_guard = """ +auto current_device = guard_.current_device(); +if (C10_UNLIKELY(current_device.has_value())) { + TORCH_INTERNAL_ASSERT(*current_device == options.device(), + "structured kernels don't support multi-device outputs"); +} else { + guard_.set_device(options.device()); +} +""" + else: + maybe_set_guard = '' + + if k is SchemaKind.functional: + if self.dispatch_key == "Meta": + return """ +if (strides.empty()) { + outputs_[output_idx] = at::empty_meta(sizes, options); +} else { + TORCH_INTERNAL_ASSERT(0, "not implemented yet"); +} +""" + else: + expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \ + "options.device_opt(), options.pinned_memory_opt()" + if self.dispatch_key == "CPU": + empty_impl = "at::native::empty_cpu" + empty_strided_impl = "at::native::empty_strided_cpu" + elif self.dispatch_key == "CUDA": + empty_impl = "at::native::empty_cuda" + empty_strided_impl = "at::native::empty_strided_cuda" + else: + raise AssertionError("unsupported dispatch key") + return f""" +{maybe_set_guard} +if (strides.empty()) {{ + outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt()); +}} else {{ + outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts}); +}} +""" + elif k is SchemaKind.inplace: + return maybe_set_guard + elif k is SchemaKind.out: + return f""" +{maybe_set_guard} +at::native::resize_output(outputs_[output_idx], sizes); +if (!strides.empty()) {{ + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + at::native::as_strided_(outputs_[output_idx], sizes, strides); +}} else if (options.memory_format_opt().has_value()) {{ + outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); +}} +""" + else: + assert_never(k) + + # returns the definition of a ctor, as well as how to construct + # this class to a variable named op + def gen_structured_class_ctor(self, k: SchemaKind, class_name: str) -> str: + if k is SchemaKind.functional: + return "" + elif k is SchemaKind.inplace: + # TODO: Make sure out argument is guaranteed to be self + return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}" + elif k is SchemaKind.out: + # TODO: Stop hardcoding out here + return f"{class_name}(Tensor& out) : outputs_{{std::ref(out)}} {{}}" + else: + assert_never(k) + + def gen_structured_class( + self, f: NativeFunction, k: SchemaKind, *, class_name: str, parent_class: str, generate_super: bool + ) -> str: + if k is SchemaKind.functional: + assert len(f.func.returns) == 1, "multi-return not supported yet" + output_type = "Tensor" + elif k is SchemaKind.inplace: + output_type = "std::reference_wrapper" + elif k is SchemaKind.out: + assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet" + output_type = "std::reference_wrapper" + + if self.dispatch_key == 'CUDA': + if self.rocm: + guard_field = 'c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;' + else: + guard_field = 'c10::cuda::OptionalCUDAGuard guard_;' + else: + guard_field = '' + + return f""" +struct {class_name} final : public {parent_class} {{ + {self.gen_structured_class_ctor(k, class_name)} + {self.gen_structured_class_set_output(k, parent_class, generate_super)} + const Tensor& maybe_get_output(int64_t output_idx) override {{ + return outputs_[output_idx]; + }} + std::array<{output_type}, {len(f.func.returns)}> outputs_; + {guard_field} +}}; +""" + def gen_structured(self, g: StructuredNativeFunctions) -> List[str]: if self.dispatch_key == 'Meta': assert self.dispatch_key not in g.out.dispatch, \ @@ -250,6 +378,8 @@ def gen_structured(self, g: StructuredNativeFunctions) -> List[str]: "functions, they will be automatically generated for you" elif self.dispatch_key not in g.out.dispatch: return [] + elif not is_structured_dispatch_key(self.dispatch_key): + return list(mapMaybe(self.gen_unstructured, g.functions())) # Inner helper function to close over g # TODO: This function has a lot of similarity with gen_unstructured. If @@ -261,7 +391,6 @@ def gen_one(f: NativeFunction) -> Optional[str]: # TODO: put this into StructuredNativeFunctions itself functional_func = g.out.func.signature() functional_sig = DispatcherSignature.from_schema(functional_func) - meta_name = meta.name(functional_func) # This is a little abusive; this assumes that the functionalization # transformation ALWAYS refers to valid arguments in the original @@ -276,74 +405,75 @@ def gen_one(f: NativeFunction) -> Optional[str]: sig = NativeSignature.from_schema(f.func) if self.target is Target.DEFINITION: - # TODO: work a little harder to generate fresh names for 'result' - # TODO: less praying that I picked the right argument name for 'self' + if self.dispatch_key == 'Meta': + class_name = f"structured_{meta.name(g)}_meta_{k.name}" + parent_class = f"at::meta::{meta.name(g)}" + else: + class_name = f"structured_{g.out.dispatch[self.dispatch_key]}_{k.name}" + parent_class = f"at::native::structured_{g.out.dispatch[self.dispatch_key]}" if k is SchemaKind.functional: - out_expr = "result" - if self.dispatch_key == "Meta": - prologue = "auto result = meta_tensor_from_meta(meta_result);" - else: - prologue = "auto result = tensor_from_meta(meta_result);" + assert len(f.func.returns) == 1, "multi-return not supported yet" + out_expr = "op.outputs_[0]" + ret_expr = "std::move(op.outputs_[0])" # small optimization + op_init = f"{class_name} op;" elif k is SchemaKind.inplace: out_expr = "self" - prologue = "// TODO: consistency check assert" + ret_expr = "self" + op_init = f"{class_name} op(self);" elif k is SchemaKind.out: - # TODO: generalize this for multi-out assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet" - # TODO: properly get the expression as it was brought into - # scope by sig out_expr = f.func.arguments.out[0].name - prologue = f""" -// TODO: add a consistency check for meta_result -{out_expr}.resize_(meta_result.sizes); -""" + ret_expr = out_expr + op_init = f"{class_name} op({out_expr});" - if self.dispatch_key == "Meta": - out_impl_call = "// meta function does nothing" + if self.dispatch_key == 'Meta': + impl_call = "" else: - out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}" - out_impl_call = f"{out_impl_name}({out_expr}, {functional_exprs});" - - device_guard = "" - - if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key): - # TODO: avoid copypasting the computation of self_args, - # candidate_args and device_of - self_args = (a for a in f.func.arguments.positional if a.name == "self") - candidate_args = itertools.chain(self_args, f.func.arguments.out, f.func.arguments.positional) - device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None) - - device_guard = '' - if f.device_guard and device_of is not None: - # TODO: Use OptionalCUDAGuard when possible - device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));" - # TODO: figure out what to do about structured kernels and - # factory functions + impl_call = f"op.impl({out_expr}, {functional_exprs});" # For an overview of what this template code looks like, see # https://github.com/pytorch/rfcs/pull/9 return f"""\ +{self.gen_structured_class( + f, k, + class_name=class_name, + parent_class=parent_class, + generate_super=g.out.structured_inherits is not None +)} + {sig.defn()} {{ - {device_guard} - auto meta_result = meta::{meta_name}({functional_exprs}); - {prologue} - {out_impl_call} - return {out_expr}; + {op_init} + op.meta({functional_exprs}); + {impl_call} + return {ret_expr}; }} """ elif self.target is Target.REGISTRATION: + dispatcher_sig = DispatcherSignature.from_schema(f.func) + if local.use_c10_dispatcher() is UseC10Dispatcher.full: - payload = f'TORCH_FN({sig.name()})' + payload = f"TORCH_FN({sig.name()})" + elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: + payload = f""" +c10::impl::hacky_wrapper_for_legacy_signatures< + {dispatcher_sig.type()}, + {len(f.func.arguments.out)} +>(TORCH_FN({sig.name()})) +""" else: - payload = f'torch::CppFunction::makeUnboxedOnly({sig.name()})' + assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper + payload = f"torch::CppFunction::makeUnboxedOnly(&{sig.name()})" return f'm.impl("{f.func.name}", {payload});' else: assert_never(self.target) + # Silence mypy's "Missing return statement" error + return None return list(mapMaybe(gen_one, g.functions())) + @method_with_native_function def gen_unstructured(self, f: NativeFunction) -> Optional[str]: # for mypy type refinement; would be fixed by TODO on target assert self.target is not Target.DECLARATION @@ -369,11 +499,15 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]: cuda_guard = "" if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key): - self_args = (a for a in f.func.arguments.positional if a.name == "self") + self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else [] # There is precedence for which argument we use to do # device guard. This describes the precedence order. - candidate_args = itertools.chain(self_args, f.func.arguments.out, f.func.arguments.positional) + candidate_args = itertools.chain( + self_arg, + f.func.arguments.out, + f.func.arguments.flat_positional + ) # Only tensor like arguments are eligible device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None) @@ -425,9 +559,12 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]: if local.use_c10_dispatcher() is UseC10Dispatcher.full: payload = f"TORCH_FN({name})" elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: - payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \ - f"{dispatcher_sig.type()}>(TORCH_FN({name}))" - + payload = f""" +c10::impl::hacky_wrapper_for_legacy_signatures< + {dispatcher_sig.type()}, + {len(f.func.arguments.out)} +>(TORCH_FN({name})) +""" else: assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper payload = f"torch::CppFunction::makeUnboxedOnly(&{name})" @@ -452,7 +589,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]: name = cpp.name(f.func) - sig_group = CppSignatureGroup.from_schema(f.func, method=False) + sig_group = CppSignatureGroup.from_schema(f.func, method=False, fallback_binding=f.manual_cpp_binding) if self.target is Target.DECLARATION: result = f"CAFFE2_API {sig_group.signature.decl()};\n" @@ -462,10 +599,15 @@ def __call__(self, f: NativeFunction) -> Optional[str]: assert self.target is Target.DEFINITION - def generate_defn(sig: CppSignature) -> str: + def generate_defn(faithful: bool) -> str: dispatcher_sig = DispatcherSignature.from_schema(f.func) - dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs()) + if faithful and sig_group.faithful_signature is not None: + sig = sig_group.faithful_signature + else: + sig = sig_group.signature + + dispatcher_exprs = dispatcher.cpparguments_exprs(f.func, method=False, api_is_faithful=faithful) dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs) return f""" @@ -478,10 +620,9 @@ def generate_defn(sig: CppSignature) -> str: }} """ - result = generate_defn(sig_group.signature) + result = generate_defn(sig_group.faithful_signature is None) if sig_group.faithful_signature is not None: - if local.use_c10_dispatcher().dispatcher_uses_new_style(): - result += generate_defn(sig_group.faithful_signature) + result += generate_defn(True) return result @@ -498,12 +639,11 @@ def __call__(self, f: NativeFunction) -> Optional[str]: return None assert not f.func.is_out_fn() - assert len(f.func.arguments.positional) > 0 - assert sum(a.name == 'self' for a in f.func.arguments.positional) == 1 + assert f.func.arguments.self_arg is not None name = cpp.name(f.func) - sig_group = CppSignatureGroup.from_schema(f.func, method=True) + sig_group = CppSignatureGroup.from_schema(f.func, method=True, fallback_binding=f.manual_cpp_binding) if self.target is Target.DECLARATION: result = f"{sig_group.signature.decl()} const;\n" @@ -513,10 +653,16 @@ def __call__(self, f: NativeFunction) -> Optional[str]: assert self.target is Target.DEFINITION - def generate_defn(sig: CppSignature) -> str: + def generate_defn(faithful: bool) -> str: dispatcher_sig = DispatcherSignature.from_schema(f.func) - dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs()) + if faithful: + sig = sig_group.faithful_signature + assert sig is not None + else: + sig = sig_group.signature + + dispatcher_exprs = dispatcher.cpparguments_exprs(f.func, method=True, api_is_faithful=faithful) dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs) return f""" @@ -529,9 +675,9 @@ def generate_defn(sig: CppSignature) -> str: }} """ - result = generate_defn(sig_group.signature) + result = generate_defn(faithful=False) if sig_group.faithful_signature is not None: - result += generate_defn(sig_group.faithful_signature) + result += generate_defn(faithful=True) return result @@ -547,32 +693,73 @@ def compute_aten_op(f: NativeFunction) -> str: # Generates NativeFunctions.h, a list of forward declarations of all # actual kernel definitions we keep in aten/src/ATen/native/ @with_native_function -def compute_native_function_declaration(f: NativeFunction) -> List[str]: - ns = list(f.dispatch.values()) - - rs = [] - # Sometimes a function name shows up multiple times; only generate - # it once! - seen = set() - for n in ns: - if n in seen: - continue - if "legacy::" in n: - continue - seen.add(n) - returns_type = native.returns_type(f.func.returns) - args = native.arguments(f.func) - rs.append(f"CAFFE2_API {returns_type} {n}({', '.join(a.str_with_default() for a in args)});") +def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]: + if isinstance(g, StructuredNativeFunctions): + # only out has dispatch + meta_name = meta.name(g) + rs = [] + seen = set() + out_args = native.arguments(g.out.func) + for k, n in g.out.dispatch.items(): + if n in seen: + continue + if not is_structured_dispatch_key(k): + continue + seen.add(n) + rs.append(f"""\ +struct CAFFE2_API structured_{n} : public at::meta::{meta_name} {{ + void impl({', '.join(a.str_with_default() for a in out_args)}); +}}; +""") + + seen = set() + for f in g.functions(): + returns_type = native.returns_type(f.func.returns) + args = native.arguments(f.func) + for k, n in f.dispatch.items(): + if n in seen: + continue + if is_structured_dispatch_key(k): + continue + seen.add(n) + rs.append(f"CAFFE2_API {returns_type} {n}({', '.join(a.str_with_default() for a in args)});") + + return rs - return rs + else: + f = g + ns = list(f.dispatch.values()) + + rs = [] + # Sometimes a function name shows up multiple times; only generate + # it once! + seen = set() + for n in ns: + if n in seen: + continue + if "legacy::" in n: + continue + seen.add(n) + returns_type = native.returns_type(f.func.returns) + args = native.arguments(f.func) + rs.append(f"CAFFE2_API {returns_type} {n}({', '.join(a.str_with_default() for a in args)});") + + return rs def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str: with native_function_manager(g.out): sig = g.signature() - name = meta.name(sig) - returns_type = meta.returns_type(sig.returns) + name = meta.name(g) args = meta.arguments(sig) - return f"CAFFE2_API {returns_type} {name}({', '.join(map(str, args))});" + args_str = ', '.join(map(str, args)) + parent_class = g.out.structured_inherits + if parent_class is None: + parent_class = "at::impl::MetaBase" + return f"""\ +struct CAFFE2_API {name} : public {parent_class} {{ + void meta({args_str}); +}}; +""" # Generates RegisterBackendSelect.cpp, a series of kernels which provide # specialized computation of dispatch key for operator signatures which cannot @@ -633,12 +820,8 @@ def __call__(self, f: NativeFunction) -> Optional[str]: }} """ elif self.target is Target.REGISTRATION: - if local.use_c10_dispatcher() is UseC10Dispatcher.full: + if local.use_c10_dispatcher().dispatcher_uses_new_style(): return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));""" - elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: - return f"""m.impl("aten::{f.func.name}", - c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_sig.type()}>( - TORCH_FN({name})));""" else: assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper return f"""m.impl_UNBOXED("aten::{f.func.name}", {name});""" @@ -828,10 +1011,10 @@ def compute_declaration_yaml(f: NativeFunction) -> object: # These sets are used to conveniently test if an argument is a # kwarg-only or out argument - kwarg_only_set = set(a.name for a in f.func.arguments.kwarg_only) + kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only) out_arg_set = set(a.name for a in f.func.arguments.out) - sig_group = CppSignatureGroup.from_schema(f.func, method=False) + sig_group = CppSignatureGroup.from_schema(f.func, method=False, fallback_binding=False) cpp_args = sig_group.signature.arguments() arguments = [ compute_cpp_argument_yaml( @@ -849,7 +1032,11 @@ def compute_declaration_yaml(f: NativeFunction) -> object: for a in schema_order_jit_arguments ] - cpp_schema_order_types = [cpp.argument(a).type for a in schema_order_jit_arguments] + cpp_schema_order_types = [ + # NB: method here doesn't matter + cpp.argument(a, method=False).type for a in schema_order_jit_arguments + ] + cpp_returns = cpp.returns_type(f.func.returns) schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})" @@ -1084,11 +1271,13 @@ def make_file_manager(install_dir: str) -> FileManager: cuda_fm = make_file_manager(options.install_dir) extra_cuda_headers = '''\ +#include #include #include #include ''' if options.rocm: extra_cuda_headers = '''\ +#include #include #include #include ''' @@ -1125,11 +1314,11 @@ def make_file_manager(install_dir: str) -> FileManager: '', 'DispatchKey': dispatch_key, 'dispatch_definitions': list(concatMap( - RegisterDispatchKey(dispatch_key, Target.DEFINITION, selector), + RegisterDispatchKey(dispatch_key, Target.DEFINITION, selector, rocm=options.rocm), grouped_native_functions )), 'dispatch_registrations': list(concatMap( - RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector), + RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm), grouped_native_functions )), }) @@ -1170,7 +1359,7 @@ def make_file_manager(install_dir: str) -> FileManager: 'aten_ops': list(mapMaybe(compute_aten_op, native_functions)), }) cpu_fm.write('NativeFunctions.h', lambda: { - 'native_function_declarations': list(concatMap(compute_native_function_declaration, native_functions)), + 'native_function_declarations': list(concatMap(compute_native_function_declaration, grouped_native_functions)), }) cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions])) diff --git a/tools/codegen/model.py b/tools/codegen/model.py index f270d0737ade..0a2689860a17 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass -from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn, Sequence, Callable +from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn, Sequence, Callable, Union from enum import Enum import itertools @@ -98,6 +98,12 @@ class NativeFunction: # registrations don't participate in codegen-based selective build! manual_kernel_registration: bool + # Whether or not to skip generating TensorMethod/Functions bindings + # for this kernel. Technically, this doesn't actually skip generating + # the binding; instead, the binding gets generated to __dispatch_{funcname} + # so you can make use of the normal binding if you need it. + manual_cpp_binding: bool + # A mapping of dispatch keys to names of functions implementing # them. In native_functions.yaml, the dispatch entry is optional; in that # case, that is equivalent to having written: @@ -125,6 +131,12 @@ class NativeFunction: # in terms of the out kernel referenced by the string here. structured_delegate: Optional['OperatorName'] + # Only valid for structured kernels. Specifies alternative of what + # to inherit from when defining the meta class for the structured + # operator. This will usually be TensorIteratorBase. This also + # changes the semantics of set_output to call the parent class. + structured_inherits: Optional[str] + # Note [Abstract ATen methods] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # An abstract ATen method is one whose dispatch differs between @@ -182,6 +194,9 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': manual_kernel_registration = e.pop('manual_kernel_registration', False) assert isinstance(manual_kernel_registration, bool), f'not a bool: {manual_kernel_registration}' + manual_cpp_binding = e.pop('manual_cpp_binding', False) + assert isinstance(manual_cpp_binding, bool), f'not a bool: {manual_cpp_binding}' + device_guard = e.pop('device_guard', True) assert isinstance(device_guard, bool), f'not a bool: {device_guard}' @@ -194,6 +209,9 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': if structured_delegate_s is not None: structured_delegate = OperatorName.parse(structured_delegate_s) + structured_inherits = e.pop('structured_inherits', None) + assert structured_inherits is None or isinstance(structured_inherits, str), f'not a str: {structured_inherits}' + python_module = e.pop('python_module', None) assert python_module is None or isinstance(python_module, str), f'not a str: {python_module}' @@ -229,7 +247,9 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': variants=variants, structured=structured, structured_delegate=structured_delegate, + structured_inherits=structured_inherits, manual_kernel_registration=manual_kernel_registration, + manual_cpp_binding=manual_cpp_binding, python_module=python_module, category_override=category_override, dispatch=dispatch, @@ -261,9 +281,11 @@ def __post_init__(self) -> None: if self.structured: assert self.func.kind() == SchemaKind.out, "Put structured field on the out= " \ "variant of a function; did you mean structured_delegate?" + assert self.device_guard, "device_guard: False is not respected by structured kernels" if self.structured_delegate: assert self.func.kind() != SchemaKind.out, "structured_delegate field not allowed " \ "on out= functions; did you mean structured?" + assert self.device_guard, "device_guard: False is not respected by structured kernels" # Technically, with the asserts above, this assert is impossible to # happen assert not (self.structured and self.structured_delegate), \ @@ -386,7 +408,11 @@ class FunctionSchema: returns: Tuple['Return', ...] def schema_order_arguments(self) -> Iterator['Argument']: - return itertools.chain(self.arguments.positional, self.arguments.kwarg_only, self.arguments.out) + return itertools.chain( + self.arguments.flat_positional, + self.arguments.flat_kwarg_only, + self.arguments.out + ) @staticmethod def parse(func: str) -> 'FunctionSchema': @@ -416,7 +442,7 @@ def __post_init__(self) -> None: # This means that all mutable returns should be aliased to a keyword argument # (except for "self", which we explicitly don't treat as an out argument because of its use in methods) # See Note [is_out_fn] - out_and_self = list(self.arguments.out) + [arg for arg in self.arguments.positional if arg.name == "self"] + out_and_self = list(self.arguments.out) + [arg for arg in self.arguments.flat_positional if arg.name == "self"] mutable_returns = [ret for ret in self.returns if ret.annotation is not None and ret.annotation.is_write] for ret in mutable_returns: assert any([ret.annotation == arg.annotation for arg in out_and_self]), \ @@ -466,6 +492,9 @@ def __post_init__(self) -> None: '_foreach_round_', '_foreach_lgamma_', '_foreach_frac_', + '_foreach_reciprocal_', + '_foreach_sigmoid_', + '_foreach_trunc_', '_foreach_addcmul_.Scalar', '_foreach_addcdiv_.Scalar', '_foreach_addcmul_.ScalarList', @@ -884,7 +913,14 @@ class Arguments: out: Tuple[Argument, ...] # these are also kwarg-only @property - def positional(self) -> Sequence[Argument]: + def flat_non_out(self) -> Sequence[Argument]: + ret: List[Argument] = [] + ret.extend(self.flat_positional) + ret.extend(self.flat_kwarg_only) + return ret + + @property + def flat_positional(self) -> Sequence[Argument]: ret: List[Argument] = [] ret.extend(self.pre_self_positional) if self.self_arg is not None: @@ -894,7 +930,7 @@ def positional(self) -> Sequence[Argument]: # NB: doesn't contain out arguments @property - def kwarg_only(self) -> Sequence[Argument]: + def flat_kwarg_only(self) -> Sequence[Argument]: ret: List[Argument] = [] ret.extend(self.pre_tensor_options_kwarg_only) if self.tensor_options is not None: @@ -902,6 +938,31 @@ def kwarg_only(self) -> Sequence[Argument]: ret.extend(self.post_tensor_options_kwarg_only) return ret + @property + def non_out(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: + ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] + ret.extend(self.positional) + ret.extend(self.kwarg_only) + return ret + + @property + def positional(self) -> Sequence[Union[Argument, SelfArgument]]: + ret: List[Union[Argument, SelfArgument]] = [] + ret.extend(self.pre_self_positional) + if self.self_arg is not None: + ret.append(self.self_arg) + ret.extend(self.post_self_positional) + return ret + + @property + def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]: + ret: List[Union[Argument, TensorOptionsArguments]] = [] + ret.extend(self.pre_tensor_options_kwarg_only) + if self.tensor_options is not None: + ret.append(self.tensor_options) + ret.extend(self.post_tensor_options_kwarg_only) + return ret + def signature(self) -> 'Arguments': # dataclasses.replace could be used here, but it is less # type safe so for now I've opted to type everything out @@ -968,7 +1029,7 @@ def parse(args: str) -> 'Arguments': Input: 'int x, int y, int z' """ - # We do this in two phases. First we parse into three + # We do this in two phases. First we parse into three # main categories: positional, kwarg_only, out. # Then, we reparse positional and kwarg_only to separate # out the self argument and tensor options arguments. @@ -1041,10 +1102,10 @@ def pred(name: str, ty: Type) -> Callable[[Argument], bool]: def __str__(self) -> str: all_arguments: List[str] = [] - all_arguments.extend(map(str, self.positional)) - if self.kwarg_only or self.out: + all_arguments.extend(map(str, self.flat_positional)) + if self.flat_kwarg_only or self.out: all_arguments.append('*') - all_arguments.extend(map(str, self.kwarg_only)) + all_arguments.extend(map(str, self.flat_kwarg_only)) all_arguments.extend(map(str, self.out)) return ', '.join(all_arguments) diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py new file mode 100755 index 000000000000..2a8d1d731453 --- /dev/null +++ b/tools/fast_nvcc/fast_nvcc.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python3 + +import argparse +import asyncio +import collections +import csv +import hashlib +import itertools +import os +import pathlib +import re +import shlex +import shutil +import subprocess +import sys +import time + + +help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]... + +Run the commands given by nvcc --dryrun, in parallel. + +All flags for this script itself (see the "optional arguments" section +of --help) must be passed before the first "--". Everything after that +first "--" is passed directly to nvcc, with the --dryrun argument added. + +This script only works with the "normal" execution path of nvcc, so for +instance passing --help (after "--") doesn't work since the --help +execution path doesn't compile anything, so adding --dryrun there gives +nothing in stderr. +''' +parser = argparse.ArgumentParser(help_msg) +parser.add_argument( + '--faithful', + action='store_true', + help="don't modify the commands given by nvcc (slower)", +) +parser.add_argument( + '--graph', + metavar='FILE.dot', + help='write Graphviz DOT file with execution graph', +) +parser.add_argument( + '--nvcc', + metavar='PATH', + default='nvcc', + help='path to nvcc (default is just "nvcc")', +) +parser.add_argument( + '--save', + metavar='DIR', + help='copy intermediate files from each command into DIR', +) +parser.add_argument( + '--sequential', + action='store_true', + help='sequence commands instead of using the graph (slower)', +) +parser.add_argument( + '--table', + metavar='FILE.csv', + help='write CSV with times and intermediate file sizes', +) +parser.add_argument( + '--verbose', + metavar='FILE.txt', + help='like nvcc --verbose, but expanded and into a file', +) +default_config = parser.parse_args([]) + + +# docs about temporary directories used by NVCC +url_base = 'https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' +url_vars = f'{url_base}#keeping-intermediate-phase-files' + + +# regex for temporary file names +re_tmp = r'(? '{filename}'") + uniqueified.append(line) + return uniqueified + + +def make_rm_force(commands): + """ + Add --force to all rm commands. + """ + return [f'{c} --force' if c.startswith('rm ') else c for c in commands] + + +def print_verbose_output(*, env, commands, filename): + """ + Human-readably write nvcc --dryrun data to stderr. + """ + padding = len(str(len(commands) - 1)) + with open(filename, 'w') as f: + for name, val in env.items(): + print(f'#{" "*padding}$ {name}={val}', file=f) + for i, command in enumerate(commands): + prefix = f'{str(i).rjust(padding)}$ ' + print(f'#{prefix}{command[0]}', file=f) + for part in command[1:]: + print(f'#{" "*len(prefix)}{part}', file=f) + + +def straight_line_dependencies(commands): + """ + Return a straight-line dependency graph. + """ + return [({i - 1} if i > 0 else set()) for i in range(len(commands))] + + +def files_mentioned(command): + """ + Return fully-qualified names of all tmp files referenced by command. + """ + return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)] + + +def nvcc_data_dependencies(commands): + """ + Return a list of the set of dependencies for each command. + """ + # fatbin needs to be treated specially because while the cicc steps + # do refer to .fatbin.c files, they do so through the + # --include_file_name option, since they're generating files that + # refer to .fatbin.c file(s) that will later be created by the + # fatbinary step; so for most files, we make a data dependency from + # the later step to the earlier step, but for .fatbin.c files, the + # data dependency is sort of flipped, because the steps that use the + # files generated by cicc need to wait for the fatbinary step to + # finish first + tmp_files = {} + fatbins = collections.defaultdict(set) + graph = [] + for i, line in enumerate(commands): + deps = set() + for tmp in files_mentioned(line): + if tmp in tmp_files: + dep = tmp_files[tmp] + deps.add(dep) + if dep in fatbins: + for filename in fatbins[dep]: + if filename in tmp_files: + deps.add(tmp_files[filename]) + if tmp.endswith('.fatbin.c') and not line.startswith('fatbinary'): + fatbins[i].add(tmp) + else: + tmp_files[tmp] = i + if line.startswith('rm ') and not deps: + deps.add(i - 1) + graph.append(deps) + return graph + + +def is_weakly_connected(graph): + """ + Return true iff graph is weakly connected. + """ + neighbors = [set() for _ in graph] + for node, predecessors in enumerate(graph): + for pred in predecessors: + neighbors[pred].add(node) + neighbors[node].add(pred) + # assume nonempty graph + stack = [0] + found = {0} + while stack: + node = stack.pop() + for neighbor in neighbors[node]: + if neighbor not in found: + found.add(neighbor) + stack.append(neighbor) + return len(found) == len(graph) + + +def warn_if_not_weakly_connected(graph): + """ + Warn the user if the execution graph is not weakly connected. + """ + if not is_weakly_connected(graph): + fast_nvcc_warn('execution graph is not (weakly) connected') + + +def print_dot_graph(*, commands, graph, filename): + """ + Print a DOT file displaying short versions of the commands in graph. + """ + def name(k): + return f'"{k} {os.path.basename(commands[k][0])}"' + with open(filename, 'w') as f: + print('digraph {', file=f) + # print all nodes, in case it's disconnected + for i in range(len(graph)): + print(f' {name(i)};', file=f) + for i, deps in enumerate(graph): + for j in deps: + print(f' {name(j)} -> {name(i)};', file=f) + print('}', file=f) + + +async def run_command(command, *, env, deps, gather_data, i, save): + """ + Run the command with the given env after waiting for deps. + """ + for task in deps: + await task + if gather_data: + t1 = time.monotonic() + proc = await asyncio.create_subprocess_shell( + command, + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + code = proc.returncode + results = {'exit_code': code, 'stdout': stdout, 'stderr': stderr} + if gather_data: + t2 = time.monotonic() + results['time'] = t2 - t1 + sizes = {} + for tmp_file in files_mentioned(command): + if os.path.exists(tmp_file): + sizes[tmp_file] = os.path.getsize(tmp_file) + else: + sizes[tmp_file] = 0 + results['files'] = sizes + if save: + dest = pathlib.Path(save) / str(i) + dest.mkdir() + for src in map(pathlib.Path, files_mentioned(command)): + if src.exists(): + shutil.copy2(src, dest / (src.name)) + return results + + +async def run_graph(*, env, commands, graph, gather_data, save): + """ + Return outputs/errors (and optionally time/file info) from commands. + """ + tasks = [] + for i, (command, indices) in enumerate(zip(commands, graph)): + deps = {tasks[j] for j in indices} + tasks.append(asyncio.create_task(run_command( + command, + env=env, + deps=deps, + gather_data=gather_data, + i=i, + save=save, + ))) + return [await task for task in tasks] + + +def print_command_outputs(command_results): + """ + Print captured stdout and stderr from commands. + """ + for result in command_results: + sys.stdout.write(result['stdout'].decode('ascii')) + sys.stderr.write(result['stderr'].decode('ascii')) + + +def write_log_csv(command_parts, command_results, *, filename): + """ + Write a CSV file of the times and /tmp file sizes from each command. + """ + tmp_files = [] + for result in command_results: + tmp_files.extend(result['files'].keys()) + with open(filename, 'w', newline='') as csvfile: + fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files)) + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for i, result in enumerate(command_results): + command = f'{i} {os.path.basename(command_parts[i][0])}' + row = {'command': command, 'seconds': result['time']} + writer.writerow({**row, **result['files']}) + + +def exit_code(results): + """ + Aggregate individual exit codes into a single code. + """ + for result in results: + code = result['exit_code'] + if code != 0: + return code + return 0 + + +def fast_nvcc(args, *, config=default_config): + """ + Emulate the result of calling the given nvcc binary with args. + + Should run faster than plain nvcc. + """ + warn_if_windows() + warn_if_tmpdir_flag(args) + dryrun_data = nvcc_dryrun_data(config.nvcc, args) + env = dryrun_data['env'] + warn_if_tmpdir_set(env) + commands = dryrun_data['commands'] + if not config.faithful: + commands = make_rm_force(unique_module_id_files(commands)) + command_parts = list(map(shlex.split, commands)) + if config.verbose: + print_verbose_output( + env=env, + commands=command_parts, + filename=config.verbose, + ) + graph = nvcc_data_dependencies(commands) + warn_if_not_weakly_connected(graph) + if config.graph: + print_dot_graph( + commands=command_parts, + graph=graph, + filename=config.graph, + ) + if config.sequential: + graph = straight_line_dependencies(commands) + results = asyncio.run(run_graph( + env=env, + commands=commands, + graph=graph, + gather_data=bool(config.table), + save=config.save, + )) + print_command_outputs(results) + if config.table: + write_log_csv(command_parts, results, filename=config.table) + return exit_code([dryrun_data] + results) + + +def our_arg(arg): + return arg != '--' + + +if __name__ == '__main__': + argv = sys.argv[1:] + us = list(itertools.takewhile(our_arg, argv)) + them = list(itertools.dropwhile(our_arg, argv)) + sys.exit(fast_nvcc(them[1:], config=parser.parse_args(us))) diff --git a/tools/jit/gen_unboxing_wrappers.py b/tools/jit/gen_unboxing_wrappers.py index f2896fac7f22..267b5a3b221a 100644 --- a/tools/jit/gen_unboxing_wrappers.py +++ b/tools/jit/gen_unboxing_wrappers.py @@ -49,6 +49,7 @@ 'std::string': 'str', 'std::string?': 'str?', 'Scalar': 'Scalar', + 'ScalarList': 'Scalar[]', 'MemoryFormat': 'MemoryFormat', 'MemoryFormat?': 'MemoryFormat?', 'QScheme': 'QScheme', @@ -131,6 +132,7 @@ def jit_type_of(arg): 'Tensor?': 'toOptionalTensor({})', 'Tensor?[]': 'toListOfOptionalTensor({})', 'TensorList': '{}.toTensorVector()', + 'ScalarList': '{}.toScalarVector()', 'bool': '{}.toBool()', 'bool?': '{}.toOptional()', 'double': '{}.toDouble()', diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 617f997a8d76..d2073bec9a27 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -3,13 +3,14 @@ import collections from pprint import pformat -import yaml -import re import argparse -from ..autograd.utils import YamlLoader, CodeTemplate, write, group_declarations_by_op_name, is_tensor_method, is_torch_function -from ..autograd.gen_python_functions import SKIP_PYTHON_BINDINGS, SKIP_PYTHON_BINDINGS_SIGNATURES -from ..autograd.gen_autograd import load_aten_declarations +from tools.codegen.model import * +from tools.codegen.api.python import * +from typing import Sequence, List, Dict + +from ..autograd.utils import CodeTemplate, write +from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads """ This module implements generation of type stubs for PyTorch, @@ -28,60 +29,38 @@ (the latter case should be pretty rare). - We go through automatically bound functions based on the - type information recorded in Declarations.yaml and + type information recorded in native_functions.yaml and generate type hints for them (generate_type_hints) There are a number of type hints which we've special-cased; read gen_pyi for the gory details. """ -# TODO: remove after migrating entire codegen to the new data model. -def should_generate_python_binding(declaration): - name = declaration['name'] - for pattern in SKIP_PYTHON_BINDINGS: - if re.match('^' + pattern + '$', name): - return False - - simple_types = [arg['simple_type'] for arg in declaration['arguments']] - signature = '{}({})'.format(name, ', '.join(simple_types)) - for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES: - if pattern == signature: - return False - - return True - - -def get_py_variable_methods(declarations): +def get_py_torch_functions( + python_funcs: Sequence[PythonSignatureNativeFunctionPair], + method: bool = False, +) -> Sequence[PythonSignatureGroup]: """ Get declarations (grouped by name) which should be generated - as methods on Tensor. + as either functions in the "torch" module or methods on Tensor. """ - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - not declaration.get('python_module') and - is_tensor_method(declaration)) + def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool: + return (should_generate_py_binding(python_func.function) and + not python_func.function.python_module and + Variant.function in python_func.function.variants) - return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) - - -def get_py_torch_functions(declarations): - """ - Get declarations (grouped by name) which should be generated - as functions in the "torch" module. - """ - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - not declaration.get('python_module') and - is_torch_function(declaration)) + def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: + return (should_generate_py_binding(python_func.function) and + not python_func.function.python_module and + Variant.method in python_func.function.variants) - return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) + should_bind = should_bind_method if method else should_bind_function + return group_overloads([f for f in python_funcs if should_bind(f)]) # TODO: Consider defining some aliases for our Union[...] types, to make # the stubs to read on the human eye. -needed_modules = set() - DEVICE_PARAM = "device: Union[_device, str, None]=None" FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False" @@ -143,91 +122,6 @@ def should_bind(declaration): 'floor_divide', 'floor_divide_', 'floor_divide_out', ] - -def type_to_python(typename, size=None): - """type_to_python(typename: str, size: str) -> str - - Transforms a Declarations.yaml type name into a Python type specification - as used for type hints. - """ - typename = typename.replace(' ', '') # normalize spaces, e.g., 'Generator *' - - # Disambiguate explicitly sized int/tensor lists from implicitly - # sized ones. These permit non-list inputs too. (IntArrayRef[] and - # TensorList[] are not real types; this is just for convenience.) - if typename in {'IntArrayRef', 'TensorList'} and size is not None: - typename += '[]' - - typename = { - 'Device': 'Device', - 'Generator': 'Generator', - 'IntegerTensor': 'Tensor', - 'Scalar': 'Number', - 'ScalarType': '_dtype', - 'Storage': 'Storage', - 'BoolTensor': 'Tensor', - 'IndexTensor': 'Tensor', - 'Tensor': 'Tensor', - 'MemoryFormat': 'memory_format', - 'IntArrayRef': '_size', - 'IntArrayRef[]': 'Union[_int, _size]', - 'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]', - 'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]', - 'bool': '_bool', - 'double': '_float', - 'int64_t': '_int', - 'accreal': 'Number', - 'real': 'Number', - 'void*': '_int', # data_ptr - 'void': 'None', - 'std::string': 'str', - 'Dimname': 'Union[str, ellipsis, None]', - 'DimnameList': 'Sequence[Union[str, ellipsis, None]]', - 'QScheme': '_qscheme', - 'ArrayRef' : 'Sequence[float]', - 'Stream': 'Stream', - }[typename] - - return typename - - -def arg_to_type_hint(arg): - """arg_to_type_hint(arg) -> str - - This takes one argument in a Declarations and returns a string - representing this argument in a type hint signature. - """ - name = arg['name'] - if name == 'from': # from is a Python keyword... - name += '_' - typename = type_to_python(arg['dynamic_type'], arg.get('size')) - if arg.get('is_nullable'): - typename = 'Optional[' + typename + ']' - if 'default' in arg: - default = arg['default'] - if default == 'nullptr': - default = None - elif default == 'c10::nullopt': - default = None - elif isinstance(default, str) and default.startswith('{') and default.endswith('}'): - if arg['dynamic_type'] == 'Tensor' and default == '{}': - default = None - elif arg['dynamic_type'] == 'Generator' and default == '{}': - default = None - elif arg['dynamic_type'] == 'IntArrayRef': - default = '(' + default[1:-1] + ')' - else: - raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type'])) - elif default == 'MemoryFormat::Contiguous': - default = 'contiguous_format' - elif default == 'QScheme::PER_TENSOR_AFFINE': - default = 'per_tensor_affine' - default = '={}'.format(default) - else: - default = '' - return name + ': ' + typename + default - - binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv', 'matmul', 'floordiv', 'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow', # reverse arithmetic @@ -241,7 +135,7 @@ def arg_to_type_hint(arg): all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops -def sig_for_ops(opname): +def sig_for_ops(opname: str) -> List[str]: """sig_for_ops(opname : str) -> List[str] Returns signatures for operator special functions (__add__ etc.)""" @@ -271,146 +165,35 @@ def sig_for_ops(opname): else: raise Exception("unknown op", opname) - -# Copied from 'gen_python_functions.py' -# TODO: consolidate after migrating to the new codegen model in 'tools/codegen'. -def namedtuple_fieldnames(declaration): - returns = declaration['returns'] - if len(returns) <= 1 or all(['field_name' not in x for x in returns]): - return [] - else: - def get_field_name(x): - # See Note [field_name versus name] - if 'field_name' not in x: - # When building on Windows, `PyStructSequence_UnnamedField` could not be - # resolved by the linker for some reason, which cause error in building: - # - # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol - # PyStructSequence_UnnamedField - # - # Thus, at this point in time, we do not support unnamed - # fields in namedtuple; you must either name all fields, - # or none of them. - raise ValueError("Unnamed field is not supported by codegen") - else: - return x['field_name'] - return [get_field_name(x) for x in returns] - - -def generate_type_hints(fname, decls, namedtuples, is_tensor=False): - """generate_type_hints(fname, decls, is_tensor=False) - - Generates type hints for the declarations pertaining to the function - :attr:`fname`. attr:`decls` are the declarations from the parsed - Declarations.yaml. - :attr:`namedtuples` is a dictionary for accumulating NamedTuple definitions. - The :attr:`is_tensor` flag indicates whether we are parsing - members of the Tensor class (true) or functions in the - `torch` namespace (default, false). - - This function currently encodes quite a bit about the semantics of - the translation C++ -> Python. - """ - if fname in blocklist: - return [] - +def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: type_hints = [] - dnames = ([d['name'] for d in decls]) - has_out = fname + '_out' in dnames - - if has_out: - decls = [d for d in decls if d['name'] != fname + '_out'] - - for decl in decls: - render_kw_only_separator = True # whether we add a '*' if we see a keyword only argument - python_args = [] - - has_tensor_options = 'TensorOptions' in (a['dynamic_type'] for a in decl['arguments']) - - for a in decl['arguments']: - if a['dynamic_type'] != 'TensorOptions': - if a.get('kwarg_only', False) and render_kw_only_separator: - python_args.append('*') - render_kw_only_separator = False - try: - python_args.append(arg_to_type_hint(a)) - except Exception: - print("Error while processing function {}".format(fname)) - raise - - if 'self: Tensor' in python_args: - self_index = python_args.index('self: Tensor') - python_args.remove('self: Tensor') - if is_tensor: - python_args = ['self'] + python_args - else: - python_args.insert(self_index, 'input: Tensor') - else: - if is_tensor: - raise Exception("method without self is unexpected") - - if has_out: - if render_kw_only_separator: - python_args.append('*') - render_kw_only_separator = False - python_args.append('out: Optional[Tensor]=None') - - if has_tensor_options: - if render_kw_only_separator: - python_args.append('*') - render_kw_only_separator = False - python_args += ["dtype: _dtype=None", - "layout: _layout=strided", - "device: Union[_device, str, None]=None", - "requires_grad:_bool=False"] - - python_args_s = ', '.join(python_args) - python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']] - field_names = namedtuple_fieldnames(decl) - - if field_names: - namedtuple_name = '_'.join(['namedtuple'] + field_names) - tuple_args = ['("{}", {})'.format(name, typ) for name, typ in zip(field_names, python_returns)] - namedtuple_def = 'NamedTuple("{}", [{}])'.format(namedtuple_name, ', '.join(tuple_args)) - if namedtuple_name in namedtuples: - assert namedtuples[namedtuple_name] == namedtuple_def - else: - namedtuples[namedtuple_name] = namedtuple_def - python_returns_s = namedtuple_name - elif len(python_returns) > 1: - python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']' - elif len(python_returns) == 1: - python_returns_s = python_returns[0] - else: - python_returns_s = 'None' - type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) - numargs = len(decl['arguments']) - vararg_pos = int(is_tensor) - have_vararg_version = (numargs > vararg_pos and - decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef'} and - (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and - (not is_tensor or decl['arguments'][0]['name'] == 'self')) + # Some deprecated ops that are on the blocklist are still included in pyi + if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: + return type_hints + # deprecated signatures have separate entries for their functional and out variants + # (as opposed to the native ops, which fuse the two into a single signature). + # generate the functional variant here, if an out variant exists. + if sig_group.signature.deprecated and sig_group.outplace is not None: + type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) type_hints.append(type_hint) - if have_vararg_version: - # Two things come into play here: PyTorch has the "magic" that if the first and only positional argument - # is an IntArrayRef, it will be used as a vararg variant. - # The following outputs the vararg variant, the "pass a list variant" is output above. - # The other thing is that in Python, the varargs are annotated with the element type, not the list type. - typelist = decl['arguments'][vararg_pos]['dynamic_type'] - vararg_type = '_int' - # replace first argument and eliminate '*' if present - python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] + - ': ' + vararg_type] + python_args[vararg_pos + 2:]) - python_args_s = ', '.join(python_args) - type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) - type_hints.append(type_hint) + # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument + # Generates the out variant if one exists. Otherwise, generate the functional variant + type_hint = sig_group.signature.signature_str_pyi( + skip_outputs=sig_group.outplace is None) + type_hints.append(type_hint) + + # Some operators also additionally have a vararg variant of their signature + type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( + skip_outputs=sig_group.outplace is None) + if type_hint_vararg: + type_hints.append(type_hint_vararg) return type_hints -def gen_nn_functional(out): +def gen_nn_functional(out: str) -> None: # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered # through an `_add_docstr` call imports = [ @@ -475,10 +258,10 @@ def gen_nn_functional(out): stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in')) write(out, 'torch/_C/_nn.pyi', stubs, env) -def gen_nn_pyi(out): +def gen_nn_pyi(out: str) -> None: gen_nn_functional(out) -def gen_pyi(declarations_path, out): +def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None: """gen_pyi() This function generates a pyi file for torch. @@ -491,16 +274,13 @@ def gen_pyi(declarations_path, out): # checking. If you are update this, consider if your change # also needs to update the other file. - # Load information from YAML - declarations = load_aten_declarations(declarations_path) - # Dictionary for NamedTuple definitions - namedtuples = {} + namedtuples: Dict[str, str] = {} # Generate type signatures for top-level functions # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - unsorted_function_hints = collections.defaultdict(list) + unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list) unsorted_function_hints.update({ 'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'], 'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'], @@ -560,21 +340,20 @@ def gen_pyi(declarations_path, out): ' other: Union[Tensor, Number],' ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop)) - function_declarations = get_py_torch_functions(declarations) - for name in sorted(function_declarations.keys()): - unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name], namedtuples) - - # Generate type signatures for deprecated functions - - # TODO: Maybe we shouldn't generate type hints for deprecated - # functions :) However, examples like those addcdiv rely on these. - with open('tools/autograd/deprecated.yaml', 'r') as f: - deprecated = yaml.load(f, Loader=YamlLoader) - for d in deprecated: - name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups() - sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')] - sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig] - unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig))) + function_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=False, pyi=True) + sig_groups = get_py_torch_functions(function_signatures) + for group in sorted(sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_function_hints[name] += generate_type_hints(group) + + named_tuple = group.signature.returns.named_tuple_pyi() + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def function_hints = [] for name, hints in sorted(unsorted_function_hints.items()): @@ -585,26 +364,26 @@ def gen_pyi(declarations_path, out): # Generate type signatures for Tensor methods # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - unsorted_tensor_method_hints = collections.defaultdict(list) + unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list) unsorted_tensor_method_hints.update({ 'size': ['def size(self) -> Size: ...', 'def size(self, _int) -> _int: ...'], 'stride': ['def stride(self) -> Tuple[_int]: ...', 'def stride(self, _int) -> _int: ...'], - 'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'. - format(type_to_python('IntArrayRef'), FACTORY_PARAMS)], + 'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'. + format(FACTORY_PARAMS)], 'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)], # new and __init__ have the same signatures differ only in return type # Adapted from legacy_tensor_ctor and legacy_tensor_new 'new': ['def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM), 'def new(self, storage: Storage) -> Tensor: ...', 'def new(self, other: Tensor) -> Tensor: ...', - 'def new(self, size: {}, *, {}) -> Tensor: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM), + 'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM), ], '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM), 'def __init__(self, storage: Storage) -> None: ...', 'def __init__(self, other: Tensor) -> None: ...', - 'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM), + 'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM), ], 'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."], # clamp has no default values in the Declarations @@ -679,10 +458,23 @@ def gen_pyi(declarations_path, out): for name in simple_conversions: unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name)) - tensor_method_declarations = get_py_variable_methods(declarations) - for name in sorted(tensor_method_declarations.keys()): - unsorted_tensor_method_hints[name] += \ - generate_type_hints(name, tensor_method_declarations[name], namedtuples, is_tensor=True) + # pyi tensor methods don't currently include deprecated signatures for some reason + # TODO: we should probably add them in + tensor_method_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True) + tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True) + + for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_tensor_method_hints[name] += generate_type_hints(group) + + named_tuple = group.signature.returns.named_tuple_pyi() + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def for op in all_ops: name = '__{}__'.format(op) @@ -764,17 +556,20 @@ def gen_pyi(declarations_path, out): gen_nn_pyi(out) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description='Generate type stubs for PyTorch') - parser.add_argument('--declarations-path', metavar='DECL', - default='torch/share/ATen/Declarations.yaml', - help='path to Declarations.yaml') + parser.add_argument('--native-functions-path', metavar='NATIVE', + default='aten/src/ATen/native/native_functions.yaml', + help='path to native_functions.yaml') + parser.add_argument('--deprecated-functions-path', metavar='DEPRECATED', + default='tools/autograd/deprecated.yaml', + help='path to deprecated.yaml') parser.add_argument('--out', metavar='OUT', default='.', help='path to output directory') args = parser.parse_args() - gen_pyi(args.declarations_path, args.out) + gen_pyi(args.native_functions_path, args.deprecated_functions_path, args.out) if __name__ == '__main__': diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index bcc847e825ad..9b1d6fd4a55f 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -234,9 +234,9 @@ add_custom_command( "${TORCH_SRC_DIR}/nn/functional.pyi" COMMAND "${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi - --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" + --native-functions-path "aten/src/ATen/native/native_functions.yaml" + --deprecated-functions-path "tools/autograd/deprecated.yaml" DEPENDS - "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" "${TORCH_SRC_DIR}/_C/__init__.pyi.in" "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in" "${TORCH_SRC_DIR}/nn/functional.pyi.in" diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in index 1360ef079725..1afd8e6c73d7 100644 --- a/torch/_C/_VariableFunctions.pyi.in +++ b/torch/_C/_VariableFunctions.pyi.in @@ -1,6 +1,6 @@ # ${generated_comment} -from torch import Tensor, Generator, strided, memory_format, contiguous_format +from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar from torch._six import inf diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index cbb5b2452e21..2a31552068a1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -165,7 +165,10 @@ def wait(fut: Future) -> Any: ... def _collect_all(futures: List[Future]) -> Future: ... def unify_type_list(types: List[JitType]) -> JitType: ... -def _freeze_module(module: ScriptModule, preserved_attrs: List[str], freeze_interfaces: _bool = True) -> ScriptModule: ... +def _freeze_module(module: ScriptModule, + preserved_attrs: List[str] = [], + freeze_interfaces: _bool = True, + preserveParameters: _bool = True) -> ScriptModule: ... def _is_tracing() -> _bool: ... def _jit_init() -> _bool: ... def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ... @@ -217,6 +220,8 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ... # Defined in torch/csrc/jit/python/script_init.cpp ResolutionCallback = Callable[[str], Callable[..., Any]] +# Defined in torch/csrc/jit/python/script_init.cpp +# and torch/csrc/jit/python/init.cpp def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ... def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ... def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ... @@ -246,6 +251,55 @@ def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallb def _create_module_with_type(ty: JitType) -> ScriptModule: ... def _run_emit_module_hook(m: ScriptModule): ... def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ... + +def _jit_pass_lower_all_tuples(graph: Graph) -> None: ... +def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ... +def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ... +def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ... +def _jit_pass_fixup_onnx_loop_node_inputs(n: Node) -> None: ... +def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ... +def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ... +def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ... +def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ... +def _jit_pass_fuse_addmm(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess(graph: Graph) -> None: ... +def _jit_pass_onnx_prepare_inplace_ops_for_onnx(graph: Graph) -> None: ... +def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ... +def _jit_pass_onnx_remove_print(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ... +def _jit_pass_onnx_unpack_quantized_weights( + graph: Graph, + paramsDict: Dict[str, IValue] +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_quantization_insert_permutes( + graph: Graph, + paramsDict: Dict[str, IValue] +) -> Dict[str, IValue]: ... +def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ... +def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ... +def _jit_pass_erase_number_types(graph: Graph) -> None: ... +def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ... +def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ... +def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ... +def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ... +def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ... +def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ... +def _jit_pass_inline_fork_wait(graph: Graph) -> None: ... +def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ... +def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ... +def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ... +def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ... +def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ... +def _jit_decay_packed_param_input_types(graph: Graph) -> None: ... +def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ... +def _jit_pass_onnx_block( + old_block: Block, + new_block: Block, + operator_export_type: _onnx.OperatorExportTypes, + env: Dict[Value, Value] +) -> None: ... +def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ... + def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ... def _jit_script_compile_overload( qualname: str, @@ -281,8 +335,18 @@ def import_ir_module_from_buffer( extra_files: Dict[str, Any] ) -> ScriptModule: ... +def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ... +def _check_onnx_proto(proto: str) -> None: ... +def _propagate_and_assign_input_shapes( + graph: Graph, + inputs: Tuple[Tensor, ...], + with_grad: _bool, + propagate: _bool +) -> Graph: ... + # Defined in torch/torch/csrc/jit/ir/ir.h class Graph: + def eraseInput(self, i: _int) -> None: ... ... # Defined in torch/csrc/jit/ir/ir.h @@ -366,8 +430,8 @@ class ScriptFunction: def qualified_name(self) -> str: ... class ScriptMethod: + graph: Graph ... - class ModuleDict: def __init__(self, mod: ScriptModule) -> None: ... def items(self) -> List[Tuple[str, Any]]: ... @@ -378,6 +442,10 @@ class ParameterDict: class BufferDict: def __init__(self, mod: ScriptModule) -> None: ... +# Defined in torch/csrc/jit/api/module.h +class Module: + ... + # Defined in torch/csrc/Module.cpp def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension def _autograd_init() -> _bool: ... # THPAutograd_initExtension @@ -667,6 +735,10 @@ class _CudaEventBase: def synchronize(self) -> None: ... def ipc_handle(self) -> bytes: ... +# Defined in torch/csrc/cuda/Graph.cpp +class _CudaGraphBase: + ... + # Defined in torch/csrc/DataLoader.cpp def _set_worker_signal_handlers(*arg: Any) -> None: ... # THPModule_setWorkerSignalHandlers def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ... # THPModule_setWorkerPIDs diff --git a/torch/_C/_onnx.pyi b/torch/_C/_onnx.pyi index 51f16566ce6c..7ab3cd9c567d 100644 --- a/torch/_C/_onnx.pyi +++ b/torch/_C/_onnx.pyi @@ -29,6 +29,7 @@ class OperatorExportTypes(Enum): ONNX_ATEN = ... ONNX_ATEN_FALLBACK = ... RAW = ... + ONNX_FALLTHROUGH = ... class TrainingMode(Enum): EVAL = ... diff --git a/torch/__init__.py b/torch/__init__.py index 403c192b47e9..30c328c1da6f 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -359,11 +359,13 @@ def set_deterministic(d): * :class:`torch.nn.FractionalMaxPool2d` when called on a CUDA tensor that requires grad * :class:`torch.nn.FractionalMaxPool3d` when called on a CUDA tensor that requires grad * :func:`torch.nn.functional.interpolate` when called on a CUDA tensor that requires grad - and one of the following modes is used: - - `linear` - - `bilinear` - - `bicubic` - - `trilinear` + and one of the following modes is used: + + - `linear` + - `bilinear` + - `bicubic` + - `trilinear` + * :class:`torch.nn.ReflectionPad1d` when called on a CUDA tensor that requires grad * :class:`torch.nn.ReflectionPad2d` when called on a CUDA tensor that requires grad * :class:`torch.nn.ReplicationPad1d` when called on a CUDA tensor that requires grad diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index ec0ad81dced0..02b666493f9a 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -262,7 +262,7 @@ def _symeig_backward(D_grad, U_grad, A, D, U, largest): class LOBPCGAutogradFunction(torch.autograd.Function): @staticmethod - def forward(ctx, + def forward(ctx, # type: ignore[override] A: Tensor, k: Optional[int] = None, B: Optional[Tensor] = None, @@ -606,7 +606,7 @@ def _lobpcg(A: Tensor, bparams['ortho_use_drop'] = bparams.get('ortho_use_drop', False) if not torch.jit.is_scripting(): - LOBPCG.call_tracker = LOBPCG_call_tracker + LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore if len(A.shape) > 2: N = int(torch.prod(torch.tensor(A.shape[:-2]))) @@ -628,7 +628,7 @@ def _lobpcg(A: Tensor, bXret[i] = worker.X[:, :k] if not torch.jit.is_scripting(): - LOBPCG.call_tracker = LOBPCG_call_tracker_orig + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k)) @@ -640,7 +640,7 @@ def _lobpcg(A: Tensor, worker.run() if not torch.jit.is_scripting(): - LOBPCG.call_tracker = LOBPCG_call_tracker_orig + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore return worker.E[:k], worker.X[:, :k] diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 08c6cbc56ac6..5945713934ba 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -1,12 +1,12 @@ import math import torch from torch._six import inf -from typing import Union, Optional +from typing import Optional class __PrinterOptions(object): precision: int = 4 - threshold: Union[str, float] = 1000 + threshold: float = 1000 edgeitems: int = 3 linewidth: int = 80 sci_mode: Optional[bool] = None diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 56aec4668b0d..d9f7e8018264 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -4638,7 +4638,7 @@ def merge_dicts(*dicts): add_docstr(torch.lu_solve, r""" -lu_solve(input, LU_data, LU_pivots, *, out=None) -> Tensor +lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted LU factorization of A from :meth:`torch.lu`. @@ -6786,7 +6786,7 @@ def merge_dicts(*dicts): The shape of the tensor is defined by the variable argument :attr:`size`. -.. note: +.. note:: With the global dtype default (``torch.float32``), this function returns a tensor with dtype ``torch.int64``. diff --git a/torch/_utils.py b/torch/_utils.py index 6336e2d937d7..fbee17167b56 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -7,6 +7,7 @@ import traceback + def _type(self, dtype=None, non_blocking=False, **kwargs): """Returns the type if `dtype` is not provided, else casts this object to the specified type. @@ -491,3 +492,12 @@ def _get_device_index(device, optional=False, allow_cpu=False) -> int: raise ValueError('Expected a torch.device with a specified index ' 'or an integer, but got:{}'.format(device)) return device_idx + + +def _handle_complex(tensor): + """ + Returns a real view of a tensor if complex dtype else just the tensor + need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule + """ + return torch.view_as_real(tensor) if not isinstance(tensor, + torch.nn.UninitializedParameter) and tensor.is_complex() else tensor diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 71537c562013..380b24edfaab 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -257,3 +257,5 @@ def variable(*args, **kwargs): if kineto_available(): from torch._C._autograd import (ProfilerResult, KinetoEvent, _prepare_profiler, _enable_profiler, _disable_profiler) + +from . import profiler diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index b23ab81ada93..3795b6e4f914 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -646,6 +646,7 @@ bool THCPComplexFloatStorage_init(PyObject *module); void THCPStream_init(PyObject *module); void THCPEvent_init(PyObject *module); +void THCPGraph_init(PyObject *module); #ifdef USE_CUDA PyMethodDef* THCPModule_methods(); @@ -786,6 +787,7 @@ PyObject* initModule() { THCPStream_init(module); THCPEvent_init(module); + THCPGraph_init(module); #endif auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) { diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index cb7ae7bb8424..ed08e541661b 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -153,15 +153,15 @@ Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional &grads, const Te const Tensor& A, const Tensor& Q, const Tensor& R) -> Tensor { - // For square and deep (tall) case we refer - // Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear - // Algebra Functions with Application in Optimum Experimental Design - // (Extended Version) The derivative for the QR decomposition is adapted - // from Eq. 42 of the above reference. - - // Compute R (R')^{T} + // For square and deep (tall) case we refer: + // Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra. + // https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition) + // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks. + // https://arxiv.org/abs/1903.09650 Section 3. QR factorization + // For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html + + // Compute R grad_R^H Tensor R_term; if (grad_R.defined()) { - R_term = at::matmul(R, grad_R.transpose(-2, -1)); + R_term = at::matmul(R, grad_R.conj().transpose(-2, -1)); } else { // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - // Compute Q^{T} Q' + // Compute grad_Q^H Q Tensor Q_term; if (grad_Q.defined()) { - Q_term = at::matmul(Q.transpose(-2, -1), grad_Q); + Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q); } else { // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - // We want to compute: (rhs_solve_1 . R^{-T}) - // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T} + Tensor M = R_term - Q_term; + + // Compute M = (tril(M) + tril(M).conj().transpose(-2, -1)) * 0.5 Identity + Tensor M_tril = at::tril(M); + M = M_tril + M_tril.conj().transpose(-2, -1); + M.diagonal(0, -2, -1).mul_(0.5); + + Tensor rhs_term; + if (grad_Q.defined()) { + rhs_term = grad_Q + at::matmul(Q, M); + } else { + rhs_term = at::matmul(Q, M); + } + + // We want to compute: (rhs_term @ R^{-H}) + // Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_solve_1^H)^H // Since R is upper triangular, we can do this using - // triangular_solve(rhs_solve_1^{T}, R)^{T} - auto rhs_solve_1 = - R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1); - rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1); - Tensor solve_soln_1; - std::tie(solve_soln_1, std::ignore) = at::triangular_solve( - rhs_solve_1.transpose(-2, -1), + // triangular_solve(rhs_term^H, R)^H + Tensor grad_A; + std::tie(grad_A, std::ignore) = at::triangular_solve( + rhs_term.conj().transpose(-2, -1), R, /*upper=*/true, /*transpose=*/false, /*unitriangular=*/false); - Tensor grad_A; - if (grad_R.defined()) { - grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R); - } else { - grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1)); - } - // Successive computations involve computation of QQ^{T} which is identity when A is square - if (A.size(-1) != A.size(-2)) { - Tensor rhs_solve_2; - // We use the same trick from above for this computation - if (grad_Q.defined()) { - rhs_solve_2 = grad_Q - at::matmul(Q, Q_term); - } else { - rhs_solve_2 = -at::matmul(Q, Q_term); - } - Tensor solve_soln_2; - std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R, - /*upper=*/true, /*transpose=*/false, - /*unitriangular=*/false); - grad_A.add_(solve_soln_2.transpose(-2, -1)); - } - return grad_A; + return grad_A.conj().transpose(-2, -1); }; auto m = self.size(-2); @@ -2087,7 +2078,7 @@ Tensor qr_backward(const std::vector &grads, const Te // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y]. // To obtain grad_X we reuse the gradient formula from the square case. // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U), - // where grad_Q_prime = grad_Q + Y @ grad_V.T + // where grad_Q_prime = grad_Q + Y @ grad_V^H // and grad_Y = Q @ grad_V. // Then concatenate grads to get grad_A = [grad_X | grad_Y]. @@ -2099,8 +2090,8 @@ Tensor qr_backward(const std::vector &grads, const Te grad_V = grad_R.narrow(-1, m, n - m); // reuse grad_R to store grad_U grad_R = grad_R.narrow(-1, 0, m); - // grad_Q_prime starts with the value of Y @ grad_V.T - grad_Q_prime = at::matmul(Y, grad_V.transpose(-2, -1)); + // grad_Q_prime starts with the value of Y @ grad_V^H + grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1)); } else { // when grad_R is not defined then grad_V and grad_Q_prime // get initialized with zeros diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index a4b89ee92639..e67815e5609a 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -134,7 +134,7 @@ template inline variable_list flatten_tensor_args(Args&&... ar } // See NOTE [ Autograd View Variables ] for details. -inline Tensor as_view(const Tensor & base, Tensor tensor, bool is_differentiable, +inline Tensor as_view(const Tensor & base, const Tensor& tensor, bool is_differentiable, c10::optional> view_func=c10::nullopt, CreationMeta creation_meta=CreationMeta::DEFAULT) { auto base_var = Variable(base); @@ -194,16 +194,16 @@ inline Tensor as_view(const Tensor & base, Tensor tensor, bool is_differentiable base_var = base_var._base(); } if (is_differentiable) { - return make_variable_differentiable_view(std::move(base_var), std::move(tensor), creation_meta, std::move(view_func)); + return make_variable_differentiable_view(std::move(base_var), tensor, creation_meta, std::move(view_func)); } else { TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, "Non-differentiable views must have creation_meta=CreationMeta::DEFAULT"); - return make_variable_non_differentiable_view(std::move(base_var), std::move(tensor)); + return make_variable_non_differentiable_view(std::move(base_var), tensor); } } // See NOTE [ Autograd View Variables ] for details. -inline std::vector as_view(const Tensor & base, std::vector tensors, bool is_differentiable, +inline std::vector as_view(const Tensor & base, std::vector& tensors, bool is_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) { auto base_var = Variable(base); if (base_var.is_view()) { @@ -211,11 +211,11 @@ inline std::vector as_view(const Tensor & base, std::vector tens } for(Tensor &tensor : tensors) { if (is_differentiable) { - tensor = make_variable_differentiable_view(base_var, std::move(tensor), creation_meta); + tensor = make_variable_differentiable_view(base_var, tensor, creation_meta); } else { TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, "Non-differentiable views must have creation_meta=CreationMeta::DEFAULT"); - tensor = make_variable_non_differentiable_view(base_var, std::move(tensor)); + tensor = make_variable_non_differentiable_view(base_var, tensor); } } return tensors; diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 09dc048f214b..44171e1a3b1b 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -133,26 +133,33 @@ struct TORCH_API Node : std::enable_shared_from_this { /// Evaluates the function on the given inputs and returns the result of the /// function call. variable_list operator()(variable_list&& inputs) { - // Using RecordFunction to trogger observers in the backward pass - at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION); - if (guard.isActive()) { - // Using sequence number and thread id to correlate with - // the forward pass function - guard.setForwardThreadId(thread_id_); - if (guard.needsInputs()) { - guard.before( - name(), - std::vector(inputs.begin(), inputs.end()), - sequence_nr()); - } else { - guard.before(name(), sequence_nr()); - } - } // In the first iteration of named tensors, autograd ignores names and // operates on unnamed tensors. In the long term, autograd should // probably operate with names. at::NoNamesGuard no_names_guard; - return apply(std::move(inputs)); + + bool pre_sampled = false; + if (at::shouldRunRecordFunction(&pre_sampled)) { + // Using RecordFunction to trogger observers in the backward pass + at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION, pre_sampled); + if (guard.isActive()) { + // Using sequence number and thread id to correlate with + // the forward pass function + guard.setForwardThreadId(thread_id_); + if (guard.needsInputs()) { + guard.before( + name(), + std::vector(inputs.begin(), inputs.end()), + sequence_nr()); + } else { + guard.before(name(), sequence_nr()); + } + } + // keeping stack guard object alive during the call + return apply(std::move(inputs)); + } else { + return apply(std::move(inputs)); + } } // Graph Connectivity API diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 78336ded0d88..488b7be9bd8a 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -173,8 +173,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { }); m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { auto cb = at::RecordFunctionCallback( - [](const at::RecordFunction&) {}, - [](const at::RecordFunction&) {}) + [](const at::RecordFunction&) { return nullptr; }, + [](const at::RecordFunction&, at::ObserverContext*) {}) .needsInputs(true) .samplingProb(sampling_prob); if (is_global) { diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 7c91e76490a1..ac6ef84104f3 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -242,6 +242,7 @@ void prepareProfiler( if (!libkineto::api().isProfilerRegistered()) { libkineto_init(); + libkineto::api().suppressLogMessages(); } if (!libkineto::api().isProfilerInitialized()) { diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 88cf22321865..eb52aec8920d 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -417,7 +417,7 @@ void pushProfilingCallbacksLegacy() { [](const at::RecordFunction& fn) { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { - return; + return nullptr; } bool record_cuda = state_ptr->config().state == ProfilerState::CUDA; @@ -432,8 +432,10 @@ void pushProfilingCallbacksLegacy() { } else { state_ptr->pushRange(fn, record_cuda, msg); } + + return nullptr; }, - [](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn, at::ObserverContext*) { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { return; diff --git a/torch/csrc/autograd/record_function_ops.cpp b/torch/csrc/autograd/record_function_ops.cpp index 633d0f177295..da8cd22fbbc9 100644 --- a/torch/csrc/autograd/record_function_ops.cpp +++ b/torch/csrc/autograd/record_function_ops.cpp @@ -65,10 +65,10 @@ c10::intrusive_ptr _call_end_callbacks_on_fut( } // Internal only, do not use directly, use Python's record_function() -static auto registry = - RegisterOperators() - .op("profiler::_record_function_enter", &record_function_enter) - .op("profiler::_record_function_exit", &record_function_exit); +TORCH_LIBRARY_FRAGMENT(profiler, m) { + m.def("_record_function_enter", &record_function_enter); + m.def("_record_function_exit", &record_function_exit); +} // Needed to register JIT operator in operator registry below c10::AliasAnalysisKind aliasAnalysisFromSchema() { diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 352e315de7ad..cb8a763f246b 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -449,17 +449,28 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta { // Differentiable view. Track history with DifferentiableViewMeta. inline Variable make_variable_differentiable_view( Variable base, - at::Tensor data, + const at::Tensor& data, CreationMeta creation_meta, c10::optional> view_func = c10::nullopt) { if (data.defined()) { - auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( - /*version_counter=*/0, - /*allow_tensor_metadata_change=*/true); - data_impl_copy->set_autograd_meta(std::make_unique( - data_impl_copy.get(), std::move(base), std::move(view_func), - creation_meta)); - return Variable(data_impl_copy); + // If we already did a TensorImpl allocation for data, just reuse it. + // Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as input), + // we have to use shallow_copy_and_detach to create a new TensorImpl to avoid + // moving leaf node into graph interior. This guarantees only 1 TensorImpl + // allocation happens in view ops. + if (data.getIntrusivePtr().unique() && data.getIntrusivePtr()->unique_version()) { + at::TensorImpl* data_impl = data.unsafeGetTensorImpl(); + data_impl->set_autograd_meta(std::make_unique( + data_impl, std::move(base), std::move(view_func), creation_meta)); + return data; + } else { + c10::intrusive_ptr data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( + /*version_counter=*/0, + /*allow_tensor_metadata_change=*/true); + data_impl_copy->set_autograd_meta(std::make_unique( + data_impl_copy.get(), std::move(base), std::move(view_func), creation_meta)); + return Variable(data_impl_copy); + } } return Variable(); } @@ -468,9 +479,12 @@ inline Variable make_variable_differentiable_view( // Non-differentiable view. Just share version counter. inline Variable make_variable_non_differentiable_view( Variable base, - at::Tensor data, + const at::Tensor& data, bool allow_tensor_metadata_change = true) { if (data.defined()) { + // Currently all of non-differentiable view ops(detach/_indices/_values) + // share the same TensorImpl as their base Tensor. Thus a new TensorImpl + // allocation here is required. auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( /*version_counter=*/impl::version_counter(base), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp new file mode 100644 index 000000000000..b258f00bcf90 --- /dev/null +++ b/torch/csrc/cuda/Graph.cpp @@ -0,0 +1,46 @@ +#include + +#include + +#include +#include + +#include + +// Cargo culted partially from csrc/distributed/c10d/init.cpp +// and partially from csrc/cuda/Stream.cpp. +// THCPStream_init is also declared at global scope. + +// Because THCPGraph_init is forward declared in the only consumer (csrc/Module.cpp) +// I don't think we need a Graph.h. + +template +using shared_ptr_class_ = py::class_>; + +void THCPGraph_init(PyObject *module) { + // Pybind11 patch notes say "py::module_" is more up-to-date syntax, + // but CI linter and some builds prefer "module". + auto torch_C_m = py::handle(module).cast(); + + shared_ptr_class_<::at::cuda::CUDAGraph>(module, "_CudaGraphBase") + .def(py::init<>()) + .def("capture_begin", + &::at::cuda::CUDAGraph::capture_begin, + py::call_guard(), + R"(``capture_begin`` begins Cuda graph capture on the current stream.)") + .def("capture_end", + &::at::cuda::CUDAGraph::capture_end, + py::call_guard(), + R"(``capture_end`` ends Cuda graph capture on the current stream. + After ``capture_end``, ``replay`` may be called on this instance.)") + .def("replay", + &::at::cuda::CUDAGraph::replay, + py::call_guard(), + R"(``replay`` replays the Cuda graph captured by this instance.)") + // reset is called in __del__ on the Python side + // (see class Graph in torch/cuda/streams.py for reasons and caveats) + .def("reset", + &::at::cuda::CUDAGraph::reset, + py::call_guard(), + R"(``reset`` deletes the graph currently held by this instance.)"); +} diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 54fc33e54424..0a7daa3a5b94 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1127,24 +1127,23 @@ that adds a prefix to each key inserted to the store. >>> ddp_model._egister_comm_hook(state = None, hook = allreduce) .. warning :: - ``get_future`` API supports only NCCL backend and single-process single-device mode. + ``get_future`` API supports only NCCL backend. The ``torch._C.Future`` object returned by this API can be used in - ``DistributedDataParallel.register_comm_hook``, but it is subject to some subtle - differences compared to ``torch.futures.Future`` due to compromises made for performance - reasons. + ``DistributedDataParallel.register_comm_hook``, and adds some CUDA-specific + features on top of ``torch.futures.Future``. In the example above, ``allreduce`` work will be done on GPU using NCCL backend, ``fut.wait()`` will return after synchronizing the appropriate NCCL streams - with PyTorch's default device streams to ensure we can have asynchronous CUDA + with PyTorch's current device streams to ensure we can have asynchronous CUDA execution and it does not wait for the entire operation to complete on GPU. Note that - ``FutureNCCL`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. + ``CUDAFuture`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. In addition, if a callback function was added by ``fut.then()``, it will wait until ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. - ``fut.then()`` will return another ``FutureNCCL`` that holds the return value of the + ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the callback and a ``CUDAEvent`` that recorded the callback stream. - Note that ``fut.done()`` returns if the enire operation is completed on the GPU. + Note that ``fut.done()`` returns only whether the operation has been enqueued on the GPU. )"); module.def( diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h index 1d0ea9bce2c8..96b632b6b111 100644 --- a/torch/csrc/jit/api/method.h +++ b/torch/csrc/jit/api/method.h @@ -31,6 +31,15 @@ struct TORCH_API Method { std::vector stack, const Kwargs& kwargs = Kwargs()); + // Run method async. Invocation on this function would invokes a JIT + // interpreter that executes ops inline, one by one, on caller's thread. A + // model can utilize async op, i.e. `fork`, to launch an asynchronous task + // which will be launched on provided `taskLauncher`. + c10::intrusive_ptr run_async( + std::vector stack, + const Kwargs& kwargs = Kwargs(), + TaskLauncher taskLauncher = at::launch); + std::shared_ptr graph() const { return function_->graph(); } diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 04eafc3d0f5d..d74905b5d3f0 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -118,6 +118,17 @@ IValue Method::operator()(std::vector stack, const Kwargs& kwargs) { return (*function_)(std::move(stack), kwargs); } +c10::intrusive_ptr Method::run_async( + std::vector stack, + const Kwargs& kwargs, + TaskLauncher taskLauncher) { + stack.insert(stack.begin(), owner()._ivalue()); + RECORD_TORCHSCRIPT_FUNCTION(name(), stack); + + function_->getSchema().checkAndNormalizeInputs(stack, kwargs); + return function_->runAsync(stack, std::move(taskLauncher)); +} + void Module::clone_method( const Module& orig, const Function& method, diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index 8bc3ba3b4c6f..e3efd924efb6 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -36,9 +36,9 @@ void runFusionGroup(const Node* fusion_node, Stack& stack) { void fuseGraph(std::shared_ptr& graph) { TORCH_CHECK( - getFuserInterface()->fn_fuse_graph != nullptr, + getFuserInterface()->fn_fuse_graph_ != nullptr, "Running the CUDA fuser requires a CUDA build."); - getFuserInterface()->fn_fuse_graph(graph); + getFuserInterface()->fn_fuse_graph_(graph); } bool canFuseNode(const Node* node) { diff --git a/torch/csrc/jit/codegen/cuda/interface.h b/torch/csrc/jit/codegen/cuda/interface.h index 7c156b1dc7c9..00d94a9f12e0 100644 --- a/torch/csrc/jit/codegen/cuda/interface.h +++ b/torch/csrc/jit/codegen/cuda/interface.h @@ -2,6 +2,7 @@ #include #include +#include /* * This file contains APIs for cuda fuser; @@ -22,7 +23,7 @@ TORCH_API std::atomic& getCudaFusionGuardMode(); struct CudaFuserInterface { void (*fn_compile_n_)(Node*) = nullptr; void (*fn_run_n_s_)(const Node*, Stack&) = nullptr; - void (*fn_fuse_graph)(std::shared_ptr&) = nullptr; + void (*fn_fuse_graph_)(std::shared_ptr&) = nullptr; bool (*fn_can_fuse_n_)(const Node*) = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/register_interface.cpp b/torch/csrc/jit/codegen/cuda/register_interface.cpp index f340a903131d..284ee05420a1 100644 --- a/torch/csrc/jit/codegen/cuda/register_interface.cpp +++ b/torch/csrc/jit/codegen/cuda/register_interface.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -20,7 +21,7 @@ class RegisterInterface { auto ptr = getFuserInterface(); ptr->fn_compile_n_ = &compileCudaFusionGroup; ptr->fn_run_n_s_ = &runCudaFusionGroup; - ptr->fn_fuse_graph = &CudaFuseGraph; + ptr->fn_fuse_graph_ = &CudaFuseGraph; ptr->fn_can_fuse_n_ = &isFusableCudaFusionGroup; RegisterProfilingNode(canFuseNode); diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 43cc24e8e29e..02ead1d6fa80 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -945,43 +946,45 @@ struct to_ir { } void emitDelete(const Delete& stmt) { - if (stmt.expr().kind() == TK_SUBSCRIPT) { - Subscript subscript(stmt.expr()); - const List& subscript_exprs = subscript.subscript_exprs(); - if (subscript_exprs[0].kind() == TK_SLICE_EXPR) { - throw ErrorReport(stmt.range()) - << "del statements only support deletion at a single index, " - "slicing is not supported" - " (see https://github.com/pytorch/pytorch/issues/31430)"; - } - const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); - const SourceRange& val_range = subscript.value().range(); - Value* idx = emitExpr(subscript_exprs[0]); - Value* val = sv->asValue(val_range, method); - - // If val is a class instance, this is a method call to a type-specific - // implementation of del defined in a __delitem__ method. - if (auto cls = val->type()->cast()) { - if (!cls->findMethod("__delitem__")) { - throw ErrorReport(stmt.range()) - << "Class does not define __delitem__"; + for (const auto& target : stmt.targets()) { + if (target.kind() == TK_SUBSCRIPT) { + Subscript subscript(target); + const List& subscript_exprs = subscript.subscript_exprs(); + if (subscript_exprs[0].kind() == TK_SLICE_EXPR) { + throw ErrorReport(target.range()) + << "del statements only support deletion at a single index, " + "slicing is not supported" + " (see https://github.com/pytorch/pytorch/issues/31430)"; } + const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); + const SourceRange& val_range = subscript.value().range(); + Value* idx = emitExpr(subscript_exprs[0]); + Value* val = sv->asValue(val_range, method); + + // If val is a class instance, this is a method call to a type-specific + // implementation of del defined in a __delitem__ method. + if (auto cls = val->type()->cast()) { + if (!cls->findMethod("__delitem__")) { + throw ErrorReport(target.range()) + << "Class does not define __delitem__"; + } - // Use MethodValue to call the method to handle recursion. - MethodValue(val, "__delitem__") - .call(stmt.range(), method, {idx}, {}, 0); + // Use MethodValue to call the method to handle recursion. + MethodValue(val, "__delitem__") + .call(stmt.range(), method, {idx}, {}, 0); + } else { + auto node = graph->create(aten::Delete, {val, idx}, 0) + ->setSourceRange(target.range()); + graph->insertNode(node); + } + } else if (target.kind() == TK_VAR) { + Var var(target); + environment_stack->removeVar(var.name(), /*check_if_removed=*/true); } else { - auto node = graph->create(aten::Delete, {val, idx}, 0) - ->setSourceRange(stmt.range()); - graph->insertNode(node); + throw ErrorReport(target.range()) + << "del statements are only supported for deleting" + " list and dict items and variables"; } - } else if (stmt.expr().kind() == TK_VAR) { - Var var(stmt.expr()); - environment_stack->removeVar(var.name(), /*check_if_removed=*/true); - } else { - throw ErrorReport(stmt.range()) - << "del statements are only supported for deleting" - " list and dict items and variables"; } } @@ -3441,7 +3444,25 @@ struct to_ir { } else { AT_ASSERT(!sliceable->type()->isSubtypeOf(TensorType::get())); } + // TODO for now let's deal with TupleType first. Ideally all list, tensor, + // string, and tuple slicing should be same (tugsbayasgalan) + if (sliceable->type()->cast()) { + std::vector> tuple_args; + // since we are only dealing with tuple slicing for now, we try to keep + // tuple args seperate for now + tuple_args.reserve(3); + + start ? tuple_args.emplace_back(start) + : tuple_args.emplace_back(c10::nullopt); + end ? tuple_args.emplace_back(end) + : tuple_args.emplace_back(c10::nullopt); + step ? tuple_args.emplace_back(step) + : tuple_args.emplace_back(c10::nullopt); + + return emitTupleSlice(loc, args[0], tuple_args); + } + // TODO this needs to be cleaned for list slicing // Default value for start is 0. if (!start) { start = graph->insertConstant(0, loc); @@ -3451,19 +3472,6 @@ struct to_ir { if (end) { args.emplace_back(loc, "end", end); } - if (sliceable->type()->cast()) { - if (step) { - // TODO: add support for slicing tuples with a step - throw ErrorReport(loc) - << "Unsupported operation: slicing tuples with a step isn't supported"; - } - - if (end) { - return emitTupleSlice(loc, args[0], args[1], /*end*/ args[2]); - } else { - return emitTupleSlice(loc, args[0], args[1], c10::nullopt); - } - } if (!step) { step = graph->insertConstant(1, loc); @@ -3826,28 +3834,37 @@ struct to_ir { Value* emitTupleSlice( const SourceRange& loc, const NamedValue& tuple_val, - const NamedValue& beg_val, - const at::optional& end_val) { + const std::vector>& tuple_args) { auto tuple_type = tuple_val.value(*graph)->type()->expect(); - int64_t beg = getAdjTupleIndex( - loc, - tuple_type, - getSliceInd(beg_val.value(*graph), loc), - /*allow_out_of_bounds*/ true); - int64_t end; int64_t tuple_len = tuple_type->elements().size(); + auto beg_val = tuple_args[0]; + auto end_val = tuple_args[1]; + auto step = tuple_args[2]; + + int64_t step_size = 1; + if (step) { + auto val = toIValue(step->value(*graph)); + TORCH_CHECK(val->isInt(), "Step size should always be an integer"); + step_size = val->to(); + } + + int64_t beg = std::numeric_limits::max(); + if (beg_val) { + beg = getAdjTupleIndex( + loc, tuple_type, getSliceInd(beg_val->value(*graph), loc), true); + } + + int64_t end = std::numeric_limits::max(); if (end_val) { end = getAdjTupleIndex( loc, tuple_type, getSliceInd(end_val->value(*graph), loc), true); - } else { - end = tuple_len; } - // slicing does not throw out of bounds errors - end = std::min(std::max((int64_t)0, end), tuple_len); - beg = std::min(std::max((int64_t)0, beg), tuple_len); + + int64_t num_values = slice_indices_adjust(tuple_len, &beg, &end, step_size); return graph - ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end)) + ->insertNode(graph->createTupleSlice( + tuple_val.value(*graph), beg, step_size, num_values)) ->output(); } @@ -3871,19 +3888,25 @@ struct to_ir { auto s_tuple_val = sv->asTupleValue(val_range, method)->asValue(val_range, method); const SliceExpr& slice = SliceExpr(subscript_exprs[0]); + std::vector> tuple_args; + tuple_args.reserve(3); auto begin = NamedValue(val_range, "begin", emitExpr(Expr(slice.startOr(0)))); + tuple_args.emplace_back(begin); if (slice.end().present()) { auto end = NamedValue(val_range, "end", emitExpr(Expr(slice.end().get()))); - auto tupleSliceValue = - emitTupleSlice(val_range, s_tuple_val, begin, end); - return std::make_shared(tupleSliceValue); + tuple_args.emplace_back(end); + } else { - auto tupleSliceValue = - emitTupleSlice(val_range, s_tuple_val, begin, c10::nullopt); - return std::make_shared(tupleSliceValue); + tuple_args.emplace_back(c10::nullopt); } + // pushing step_size to match the tuple_args + tuple_args.emplace_back(c10::nullopt); + + auto tupleSliceValue = + emitTupleSlice(val_range, s_tuple_val, tuple_args); + return std::make_shared(tupleSliceValue); } else { return std::make_shared(emitBasicSlice( range, sv->asValue(val_range, method), subscript_exprs)); diff --git a/torch/csrc/jit/frontend/parser.cpp b/torch/csrc/jit/frontend/parser.cpp index 1f5e43fff149..c079e99893a7 100644 --- a/torch/csrc/jit/frontend/parser.cpp +++ b/torch/csrc/jit/frontend/parser.cpp @@ -558,10 +558,11 @@ struct ParserImpl { return parseFunction(/*is_method=*/in_class); } case TK_DELETE: { - L.expect(TK_DELETE); - auto expr = parseExp(); + auto range = L.next().range; + auto targets = + parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp); L.expect(TK_NEWLINE); - return Delete::create(expr); + return Delete::create(range, targets); } case TK_WITH: { return parseWith(); diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index e33d93f37566..389ed6d003db 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -1120,11 +1120,11 @@ struct Delete : public Stmt { explicit Delete(const TreeRef& tree) : Stmt(tree) { tree_->match(TK_DELETE); } - Expr expr() const { - return Expr(subtree(0)); + List targets() const { + return subtree(0); } - static Delete create(const Expr& value) { - return Delete(Compound::create(TK_DELETE, value.range(), {value})); + static Delete create(const SourceRange& range, const List& targets) { + return Delete(Compound::create(TK_DELETE, range, {targets})); } }; diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index b055d29164a5..000bea53e0fc 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -486,6 +486,7 @@ void AliasDb::analyzeImpl(Node* node) { return analyzeGradOf(node); // TODO: think more about TensorExpr alias correctness case prim::TensorExprGroup: + case prim::StaticSubgraph: case prim::Constant: case prim::AutogradZero: case prim::AutogradAdd: @@ -524,6 +525,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::SetAttr: return analyzeSetAttr(node); case prim::profile_optional: + case prim::profile_ivalue: case prim::profile: makePointerTo(node->output(), node->inputs().at(0)); return; diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index fe79091c946f..65b410d82069 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1606,17 +1606,25 @@ Node* Graph::createTupleIndex( return n; } -Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) { - auto n = create(prim::TupleSlice, {tup}); - auto tuple_type = tup->type()->expect(); - n->i_(attr::beg, beg); - n->i_(attr::end, end); - std::vector output_types; - for (auto i = beg; i < end; ++i) { - output_types.push_back(tuple_type->elements().at(i)); - } - auto tt = TupleType::create(std::move(output_types)); - n->output()->setType(tt); +Node* Graph::createTupleSlice( + Value* tup, + int64_t beg, + int64_t step_size, + int64_t num_values) { + std::vector new_vals; + TupleTypePtr tt = tup->type()->expect(); + new_vals.reserve(num_values); + + int64_t i = beg; + for (int64_t j = 0; j < num_values; ++j) { + auto idx = insertConstant(IValue(static_cast(i))); + auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i])); + + new_vals.push_back(tupleIndex->output()); + i += step_size; + } + + auto n = createTuple(new_vals); return n; } @@ -2053,6 +2061,16 @@ Node* ProfileOptionalOp::allocNewInstance(Graph* g) { return new ProfileOptionalOp(g, {nullptr}); } +void ProfileIValueOp::cloneFrom(Node* other_) { + Node::cloneFrom(other_); + auto other = other_->cast(); + this->callback_ = other->getCallback(); +} + +Node* ProfileIValueOp::allocNewInstance(Graph* g) { + return new ProfileIValueOp(g, {nullptr}); +} + TypePtr NamedValue::type() const { if (value_) { return value_->type(); @@ -2061,8 +2079,9 @@ TypePtr NamedValue::type() const { } } -constexpr Symbol ProfileOp::Kind; -constexpr Symbol ProfileOptionalOp::Kind; +const Symbol ProfileOp::Kind = ::c10::prim::profile; +const Symbol ProfileOptionalOp::Kind = ::c10::prim::profile_optional; +const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue; OperatorSet::OperatorSet(std::initializer_list sig_literals) { for (const char* sig : sig_literals) { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 9db2dbdf2516..7587451d9fd4 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -440,7 +440,7 @@ struct TORCH_API Node { // instructions lowered by the interpreter and not run in the optimized graph bool notExecutedOp() const { return kind_ == prim::Constant || kind_ == prim::profile || - kind_ == prim::profile_optional; + kind_ == prim::profile_optional || kind_ == prim::profile_ivalue; } // Graphs @@ -1122,7 +1122,11 @@ struct Graph { Value* tup, Value* idx, const TypePtr& output_type); - TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end); + TORCH_API Node* createTupleSlice( + Value* tup, + int64_t beg, + int64_t step_size, + int64_t num_values); TORCH_API Node* createEnumName(Value* e); TORCH_API Node* createEnumValue(Value* e); TORCH_API Node* createList( @@ -1326,7 +1330,7 @@ inline const Graph* Value::owningGraph() const { /************* All nodes not required to be defined before Graph **************/ struct ProfileOp : public Node { - static constexpr Symbol Kind = ::c10::prim::profile; + static const Symbol Kind; ProfileOp(Graph* graph, std::function&)> callback) : Node(graph, ::c10::prim::profile), callback_(std::move(callback)) {} @@ -1346,7 +1350,7 @@ struct ProfileOp : public Node { }; struct TORCH_API ProfileOptionalOp : public Node { - static constexpr Symbol Kind = ::c10::prim::profile_optional; + static const Symbol Kind; ProfileOptionalOp( Graph* graph, std::function&)> callback) @@ -1368,6 +1372,28 @@ struct TORCH_API ProfileOptionalOp : public Node { std::function&)> callback_; }; +struct TORCH_API ProfileIValueOp : public Node { + static const Symbol Kind; + ProfileIValueOp( + Graph* graph, + std::function&)> callback) + : Node(graph, ::c10::prim::profile_ivalue), callback_(callback) {} + + void cloneFrom(Node* other_) override; + Node* allocNewInstance(Graph* g) override; + + const std::function&)>& getCallback() const { + return callback_; + } + + void setCallback(std::function&)> callback) { + callback_ = callback; + } + + private: + std::function&)> callback_; +}; + // execute a Python function, used for Ops we can't optimize but that we want to // optimize around // diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index 1b93d28e2e1a..bacd8cf29bd2 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -394,7 +394,17 @@ class InsertObserversHelper { // are observed bool shouldObserve( Node* n, - const std::unordered_set& block_observed_values) { + const std::unordered_set& block_observed_values, + QuantType quant_type) { + // Check whether node output uses can be quantized, eg cat followed by + // linear op + for (Value* v : n->outputs()) { + for (const auto& use : v->uses()) { + if (useQuantizable(use, quant_type)) { + return true; + } + } + } if (isPropagateQuantSingleInputOp(n)) { return isObserved(n->input(0), block_observed_values); } else if (isPropagateQuantBinaryOp(n)) { @@ -1528,7 +1538,8 @@ InsertObserversHelper::insertObserversFor( // If the node is one of the propagate quant node, e.g. // aten::cat, we should observe its output only // if the input of the node is observed - if (observer_opt && shouldObserve(n, block_observed_values)) { + if (observer_opt && + shouldObserve(n, block_observed_values, quant_type_)) { recordObserved( v, *observer_opt, values_to_observe, block_observed_values); } diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index aaaaf6185dde..0d2c1c20b555 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -987,7 +987,7 @@ std::tuple InsertQuantDeQuantHelper:: v->debugName(), " exists."); QParamVector qparams; - c10::QScheme qscheme; + c10::QScheme qscheme = c10::kPerTensorAffine; auto observer_module = module.attr(observer_name.value()).toModule(); auto scalar_type = observer_module.attr("dtype"); diff --git a/torch/csrc/jit/passes/reconstruct_scopes.cpp b/torch/csrc/jit/passes/reconstruct_scopes.cpp deleted file mode 100644 index 15aa5863fbf1..000000000000 --- a/torch/csrc/jit/passes/reconstruct_scopes.cpp +++ /dev/null @@ -1,206 +0,0 @@ -#include -#include - -namespace torch { -namespace jit { - -class ReconstructScopesPass { - public: - ReconstructScopesPass(const Module& m, Graph& g, std::string p) - : root_module_(m), - graph_(g), - prefix_(std::move(p)), - class_types_are_not_unique_(false){}; - void run(); - - private: - const Module& root_module_; - Graph& graph_; - std::string prefix_; - - // This boolean indicates whether there are two submodules of the same - // class type. This issue may occur in a scripted module and make it - // difficult to exactly track module information corresponding to each - // Node* after inlining the graph. Consider the following example: - - // class A(nn.Module): - // def __init__(self): - // super(A, self).__init__() - - // def forward(self, x): - // return x + 1 - - // class B(nn.Module): - // def __init__(self): - // super(B, self).__init__() - // self.A0 = A() - // self.A1 = A() - - // def forward(self, x): - // return self.A0(x) + self.A1(x) - - // m_traced = torch.jit.trace(B(), torch.Tensor([1])) - // m_scripted = torch.jit.script(B()) - - // In m_traced, self.A0 and self.A1 have different class types, but in - // m_scripted, self.A0 and self.A1 have the same class types. Therefore, - // it is difficult to distinguish 'A0' and 'A1' in the module hierarchy - // after the graph is inlined. In this case, we add a warning to let - // users know that the debugging information may be incomplete. - bool class_types_are_not_unique_; - - std::unordered_map func_to_module_; - std::unordered_map module_names_; - - void visitBlock(Block* b, const std::string& root_scope_string); - void visitNode(Node* n, const std::string& root_scope_string); - - std::string getModuleTypeName( - const Module& module, - const std::string& prefix); - void constructFunctionToModuleMap(const Module& module); - void constructRelativeNamesForModules( - const Module& module, - const std::string& prefix); - - std::string getScopeString(const InlinedCallStackEntry& frame) const; - - void appendSourceRangeInfo( - std::string& scopeString, - const InlinedCallStackEntry& frame) const; -}; - -void ReconstructScopesPass::constructFunctionToModuleMap(const Module& module) { - for (const auto& method : module.get_methods()) { - Function* func_ptr = &method.function(); - if (!class_types_are_not_unique_ && - func_to_module_.find(func_ptr) != func_to_module_.end()) { - class_types_are_not_unique_ = true; - } - func_to_module_[func_ptr] = module._ivalue(); - } - for (const Module& m : module.children()) { - constructFunctionToModuleMap(m); - } -} - -std::string ReconstructScopesPass::getModuleTypeName( - const Module& module, - const std::string& prefix) { - std::string moduleType = module.type()->str(); - size_t lastDotIndex = moduleType.rfind('.'); - if (lastDotIndex != std::string::npos) { - moduleType = moduleType.substr(lastDotIndex + 1); - } - return prefix + "(" + moduleType + ")"; -} - -void ReconstructScopesPass::constructRelativeNamesForModules( - const Module& module, - const std::string& prefix) { - module_names_[module._ivalue()] = getModuleTypeName(module, prefix); - for (const NameModule& s : module.named_children()) { - constructRelativeNamesForModules( - s.value, module_names_[module._ivalue()] + "." + s.name); - } -} - -void ReconstructScopesPass::appendSourceRangeInfo( - std::string& scopeString, - const InlinedCallStackEntry& frame) const { - SourceRange r = std::get<1>(frame); - if (r.source()) { - if (auto orig = r.source()->findSourceRangeThatGenerated(r)) { - r = *orig; - } - } - if (auto file_line_col = r.file_line_col()) { - std::string filename; - size_t line, col; - std::tie(filename, line, col) = *file_line_col; - scopeString += "<" + filename + ":" + c10::to_string(line) + ":" + - c10::to_string(col) + ">"; - } -} - -std::string ReconstructScopesPass::getScopeString( - const InlinedCallStackEntry& frame) const { - Function* f = std::get<0>(frame); - if (!func_to_module_.count(f)) { - return ""; - } - auto m = func_to_module_.at(f); - if (!module_names_.count(m)) { - return ""; - } - std::string scopeString = module_names_.at(m) + "." + f->name(); - - // When class types are not unique, the module information may be - // incomplele. In this case, we add source range information, - // which can be helpful for debugging purposes. - if (class_types_are_not_unique_) { - appendSourceRangeInfo(scopeString, frame); - } - return scopeString; -} - -void ReconstructScopesPass::visitNode( - Node* n, - const std::string& root_scope_string) { - for (Block* b : n->blocks()) { - visitBlock(b, root_scope_string); - } - ScopePtr sc = c10::make_intrusive(); - if (!n->callstack()) { - sc = sc->push(Symbol::scope(root_scope_string)); - } else { - for (const auto& frame : (*n->callstack())->vec()) { - auto name = getScopeString(frame); - GRAPH_UPDATE("Adding a scope ", name, " for node ", *n); - sc = sc->push(Symbol::scope(name)); - } - } - n->setScope(sc); - GRAPH_UPDATE("Updated node: ", *n); -} - -void ReconstructScopesPass::visitBlock( - Block* b, - const std::string& root_scope_string) { - for (Node* n : b->nodes()) { - visitNode(n, root_scope_string); - } -} - -void ReconstructScopesPass::run() { - GRAPH_DUMP("Graph before reconstructing scope", &graph_); - func_to_module_.clear(); - module_names_.clear(); - - constructFunctionToModuleMap(root_module_); - constructRelativeNamesForModules(root_module_, prefix_); - - if (class_types_are_not_unique_) { - TORCH_WARN( - "It seems that the module contain two instances of the same class type.\n", - "The current debugging program has not provided support for distinguishing ", - "the two instances of the same class type.\n", - "The module debugging information may be incomplete."); - } - - std::string root_scope_string = - getModuleTypeName(root_module_, prefix_) + ".forward"; - visitBlock(graph_.block(), root_scope_string); - GRAPH_DUMP("Graph after reconstructing scope", &graph_); -} - -void ReconstructScopes( - const Module& module, - Graph& g, - const std::string& prefix = "top") { - ReconstructScopesPass p(module, g, prefix); - p.run(); -} - -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/reconstruct_scopes.h b/torch/csrc/jit/passes/reconstruct_scopes.h deleted file mode 100644 index b08655cb3741..000000000000 --- a/torch/csrc/jit/passes/reconstruct_scopes.h +++ /dev/null @@ -1,37 +0,0 @@ -/** \brief A pass to reconstruct scopes of nodes from their inline callstacks. - * - * The pass takes the root module and a graph and for every graph node with - * non-empty inline call-stack it computes the scope from this callstack. - * - * Callstack can be thought of as a stack of pointers to Function, and Function - * in a general case may not be a part of any module. That's why this pass - * requires a root module to be passed in - we can traverse all methods of the - * module and its submodules and then recognize these methods in callstacks. - * - * Scope can be thought of as a stack of strings, so we basically converting a - * pointer to Function to a string, or in other words trying to find a name for - * a function in this module hierarchy. - * - * The produced scopes look like: - * top.submod1.function1/top.submod1.subsubmod1.function2 - * - * 'top' is the name we use for the root module itself, and it can be customized - * with an optional third argument of the pass. - * - * The pass would not change anything if inlining has not been run on the graph. - */ -#pragma once - -#include -#include - -namespace torch { -namespace jit { - -TORCH_API void ReconstructScopes( - const Module& module, - Graph& g, - const std::string& prefix); - -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index c53a71eb02e8..8a71e52db556 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -459,7 +459,7 @@ class TensorExprFuser { // fusion is done. inlineSmallFusionGroups(graph_->block()); GRAPH_DUMP("After inlining small fusion groups: ", graph_); - guardFusionGroupsAndRemoveOutputs(graph_->block()); + prepareFusionGroupAndGuardOutputs(graph_->block()); GRAPH_DUMP("After guarding fusion groups: ", graph_); removeTensorTypeSpecializations(graph_->block()); GRAPH_DUMP("After removing tensor type specializations: ", graph_); @@ -739,15 +739,29 @@ class TensorExprFuser { }; // clang-format on - // Value is only supported if operands are floats. - if (node->isMemberOf(float_only_operator_set)) { - for (const Value* v : node->inputs()) { - if (auto const& tt = v->type()->cast()) { - auto const& st = tt->scalarType(); - if (!st || !isFloatingType(*st)) { - return false; - } - } else if (!v->type()->cast()) { + for (const Value* v : node->inputs()) { + if (auto const& tt = v->type()->cast()) { + auto const& st = tt->scalarType(); + + // All tensors must be typed. + if (!st) { + return false; + } + + // Byte tensors introduce too many corner cases in type promotion. + // Better not to try to handle them. + if (*st == c10::ScalarType::Byte) { + return false; + } + + // These operators only support floats, because integer divisors need to + // raise ZeroDivisionError. + if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) { + return false; + } + } else if (node->isMemberOf(float_only_operator_set)) { + // Check scalar operands of float-only ops. + if (!v->type()->cast()) { return false; } } @@ -763,17 +777,10 @@ class TensorExprFuser { } bool canHandle(Node* node) { - REQ(node->kind() != prim::Constant); REQ(disable_shape_checks_ || allShapesAreKnown(node)); REQ(isFusableOnDevice(node)); - // Don't include nodes whose inputs are tensor constants - we cannot handle - // them at the moment. - // TODO: actually support tensor constants and remove this. for (Value* input : node->inputs()) { - if (input->node()->kind() == prim::Constant) { - REQ(!input->type()->cast()) - } if (auto const& tt = input->type()->cast()) { auto st = tt->scalarType(); if (!st) { @@ -975,11 +982,32 @@ class TensorExprFuser { } } - void guardFusionGroupsAndRemoveOutputs(Block* block) { + // TODO: support constant tensors instead of setting them as input + void liftTensorConstantsFromFusionGroups(Node* fusion_group) { + auto subgraph = SubgraphUtils::getSubgraph(fusion_group); + WithInsertPoint guard(fusion_group); + for (auto it = subgraph->block()->nodes().begin(); + it != subgraph->block()->nodes().end(); + ++it) { + auto n = *it; + if (n->kind() == prim::Constant && + n->output()->type()->cast()) { + auto constant = + fusion_group->owningGraph()->insertConstant(*toIValue(n->output())); + fusion_group->addInput(constant); + auto inputToGraph = subgraph->addInput(); + inputToGraph->setType(n->output()->type()); + n->output()->replaceAllUsesWith(inputToGraph); + it.destroyCurrent(); + } + } + } + + void prepareFusionGroupAndGuardOutputs(Block* block) { std::vector fusion_groups; for (Node* n : block->nodes()) { for (Block* b : n->blocks()) { - guardFusionGroupsAndRemoveOutputs(b); + prepareFusionGroupAndGuardOutputs(b); } if (n->kind() == prim::TensorExprGroup) { fusion_groups.push_back(n); @@ -987,6 +1015,7 @@ class TensorExprFuser { } for (Node* fusion_group : fusion_groups) { removeOutputsUsedOnlyInSize(fusion_group); + liftTensorConstantsFromFusionGroups(fusion_group); guardFusionGroup(fusion_group); } } @@ -1028,16 +1057,7 @@ Operation createTensorExprOp(const Node* node) { std::make_shared(node->g(attr::Subgraph)); return [kernel](Stack* stack) { RECORD_FUNCTION("TensorExpr", std::vector()); - if (!tensorexpr::fallbackAllowed()) { - kernel->run(*stack); - return 0; - } - - try { - kernel->run(*stack); - } catch (const std::runtime_error& e) { - kernel->fallback(*stack); - } + kernel->run(*stack); return 0; }; } diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 5f88a8a6c79d..663c9ab06a52 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -54,7 +54,6 @@ #include #include #include -#include #include #include #include @@ -340,16 +339,6 @@ void initJITBindings(PyObject* module) { subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name); subgraph_rewriter.runOnGraph(g); }) - .def( - "_jit_pass_reconstruct_scopes", - [](script::Module& module, - std::shared_ptr& g, - const std::string& prefix) { - ReconstructScopes(module, *g, prefix); - }, - py::arg("module"), - py::arg("graph"), - py::arg("prefix") = "top") .def( "_jit_pass_remove_inplace_ops", [](const std::shared_ptr& g) { return RemoveInplaceOps(g); }) diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 99b439aa185f..34ca7585be67 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -61,6 +61,8 @@ inline IValue toIValue( py::object toPyObject(IValue ivalue); +IValue toTypeInferredIValue(py::handle input); + // The PythonFutureWrapper for ivalue::Future // // NB: VISIBILITY_HIDDEN is for silencing compiling error, @@ -119,32 +121,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper // vector, but Future does not acquire GIL on destruction. auto pf = std::make_shared(std::move(cb)); -#ifdef USE_C10D_NCCL - // This callback is only used by NCCL backend, so skip this code on other - // backends and avoid importing cuda dependency. - // By default, assume that the input value is or can be casted into a tensor - // vector that has exactly one tensor. - auto record_stream_cb = [](const at::IValue& value, - const c10::Stream& stream) { - if (value.isTensorList() || value.isPyObject()) { - std::vector tensors; - if (value.isTensorList()) { - tensors = value.toTensorVector(); - } else { - pybind11::gil_scoped_acquire gil; - py::object obj = torch::jit::toPyObject(value); - tensors = torch::jit::toIValue( - obj, c10::ListType::create(c10::TensorType::get())) - .toTensorVector(); - } - TORCH_INTERNAL_ASSERT(tensors.size() == 1, "expected exactly 1 tensor"); - at::cuda::CUDAStream cuda_stream(stream); - c10::cuda::CUDACachingAllocator::recordStream( - tensors[0].storage().data_ptr(), cuda_stream); - } - }; - fut->setRecordStreamCallback(record_stream_cb); -#endif + fut->setDataPtrExtractor(&PythonFutureWrapper::dataPtrExtractor); return std::make_shared(fut->then( // Capture a copy of the ivalue::Future instead of the `this` pointer @@ -241,6 +218,23 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper std::shared_ptr getPtr() { return shared_from_this(); } + + // This callback is only used by subclasses of Future that deal with CUDA, + // in order to register the pointers on the right streams with the caching + // allocator. + static std::vector> dataPtrExtractor( + const at::IValue& value) { + if (value.isPyObject()) { + pybind11::gil_scoped_acquire gil; + py::object obj = torch::jit::toPyObject(value); + // FIXME This could fail. As a fallback we could try to pickle the + // object, since the pickler might support broader types and it is able + // to extract the tensors from the object as a vector. + auto new_value = torch::jit::toTypeInferredIValue(obj); + return at::ivalue::Future::defaultDataPtrExtractor(new_value); + } + return at::ivalue::Future::defaultDataPtrExtractor(value); + }; }; // error reporting: when reporting user-caused errors, these functions should @@ -719,6 +713,15 @@ inline IValue toIValue( const auto& attrType = classType->getAttribute(slot); const auto& attrName = classType->getAttributeName(slot); + if (!py::hasattr(obj, attrName.c_str())) { + throw py::cast_error(c10::str( + "Tried to cast object to type ", + type->repr_str(), + " but object", + " was missing attribute ", + attrName)); + } + const auto& contained = py::getattr(obj, attrName.c_str()); userObj->setSlot(slot, toIValue(contained, attrType)); } diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index f5cdea1e7eb4..6e68fe9ebec3 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -697,6 +697,16 @@ void initPythonIRBindings(PyObject* module_) { } return py::none(); }) + .def( + "varyingSizes", + [](Type& t) -> py::object { + if (auto ptt = t.expect()) { + if (auto s = ptt->sizes().sizes()) { + return py::cast(s.value()); + } + } + return py::none(); + }) .def( "strides", [](Type& t) -> py::object { diff --git a/torch/csrc/jit/python/python_tree_views.cpp b/torch/csrc/jit/python/python_tree_views.cpp index 1355352c8278..1e622bda379a 100644 --- a/torch/csrc/jit/python/python_tree_views.cpp +++ b/torch/csrc/jit/python/python_tree_views.cpp @@ -200,9 +200,10 @@ void initTreeViewBindings(PyObject* module) { r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type)); })); - py::class_(m, "Delete").def(py::init([](const Expr& expr) { - return Delete::create(expr); - })); + py::class_(m, "Delete") + .def(py::init([](const SourceRange& range, std::vector targets) { + return Delete::create(range, wrap_list(range, std::move(targets))); + })); py::class_(m, "WithItem") .def(py::init([](const SourceRange& range, const Expr& target, Var* var) { diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index ef0f2dae9e0e..5d88264a2f2c 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -320,6 +320,7 @@ struct CanEmitInline { // by the later BailOut in createBailoutBlock and its jf_index // will become invalid. v->node()->kind() != prim::TensorExprGroup && + v->node()->kind() != prim::StaticSubgraph && v->node()->kind() != prim::CudaFusionGroup && v->node()->kind() != prim::FusionGroup && v->node()->kind() != prim::BailOut && v->uses().size() == 1 && @@ -791,6 +792,9 @@ struct CodeImpl { } else if (node->cast()) { profile_function_table_.push_back( node->cast()->getCallback()); + } else if (node->cast()) { + profile_function_table_.push_back( + node->cast()->getCallback()); } else { TORCH_INTERNAL_ASSERT(false); } @@ -945,6 +949,7 @@ struct CodeImpl { case prim::BailOut: emitBailOut(node); break; + case prim::profile_ivalue: case prim::profile_optional: case prim::profile: emitProfile(node); @@ -1412,10 +1417,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { auto t = input.toTensor(); const TypePtr& expected = frame.function->type_table_[inst.X + i]; auto expected_type = expected->cast(); - if (t.defined() && - (!frames.back().symbols2dims.bindSymbolicShapes( - t.sizes(), expected_type->symbolic_sizes()) || - !expected_type->matchTensor(t))) { + if (t.defined() && !expected_type->matchTensor(t)) { push(stack, false); break; } @@ -1607,10 +1609,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } static void checkAndStartRecordFunction(Frame& frame, Stack& stack) { + bool pre_sampled = false; if (!frame.record_function && at::hasCallbacks() && - at::isRecordFunctionEnabled()) { + at::shouldRunRecordFunction(&pre_sampled)) { auto rec_fn = std::make_unique( - at::RecordScope::TORCHSCRIPT_FUNCTION); + at::RecordScope::TORCHSCRIPT_FUNCTION, pre_sampled); if (rec_fn->isActive()) { if (rec_fn->needsInputs()) { rec_fn->before( diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 0756d6b58e9f..1964679fda19 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -239,12 +239,14 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::CudaFusionGroup, // optimization pass adds it prim::CudaFusionGuard, // optimization pass adds it prim::TensorExprGroup, // optimization pass adds it + prim::StaticSubgraph, // optimization pass adds it prim::Load, // used in interpreter only prim::MMTreeReduce, // used as an optimization prim::MMBatchSide, // used as an optimization prim::Store, // used in interpreter only prim::profile, // used in interpreter only prim::profile_optional, // used in interpreter only + prim::profile_ivalue, // used in interpreter only prim::TypeCheck, // used in interpreter only prim::FallbackGraph, // converted into prim::CallFunction @@ -275,6 +277,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::CudaFusionGroup, prim::DifferentiableGraph, prim::TensorExprGroup, + prim::StaticSubgraph, prim::FunctionalGraph, prim::Constant, prim::Uninitialized, @@ -303,6 +306,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::SetAttr, prim::profile, prim::profile_optional, + prim::profile_ivalue, prim::TypeCheck, prim::Print, prim::CallFunction, diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index dc6f50350bd0..31750636d762 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -35,6 +35,18 @@ C10_DEFINE_bool( true, "If this flag is set to false TorchScript will be using the legacy/original executor"); +constexpr size_t kDefaultNumProfiledRuns = 1; +constexpr size_t kDefaultBailoutDepth = 20; + +C10_DEFINE_int64( + torch_jit_num_profiled_runs, + kDefaultNumProfiledRuns, + "Number of profiling runs"); +C10_DEFINE_int64( + torch_jit_bailout_depth, + kDefaultBailoutDepth, + "Number of re-specializations"); + namespace torch { namespace jit { @@ -46,21 +58,32 @@ static std::atomic executor_mode{true}; static std::atomic profiling_mode{true}; #endif -static std::atomic num_profiled_runs{1}; -static std::atomic bailout_depth{20}; // NOLINT +static std::atomic num_profiled_runs{kDefaultNumProfiledRuns}; +static std::atomic bailout_depth{kDefaultBailoutDepth}; std::atomic& getProfilingMode() { return profiling_mode; } + std::atomic& getExecutorMode() { return executor_mode; } std::atomic& getNumProfiledRuns() { + // Initialize num_profiled_runs from command-line flag. + static const size_t init = []() { + return num_profiled_runs = FLAGS_torch_jit_num_profiled_runs; + }(); + (void)init; // Silence clang-tidy. return num_profiled_runs; } std::atomic& getBailoutDepth() { + // Initialize bailout_depth from command-line flag. + static const size_t init = []() { + return bailout_depth = FLAGS_torch_jit_bailout_depth; + }(); + (void)init; // Silence clang-tidy. return bailout_depth; } diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index f031d957449b..d9bffa7e4644 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -822,6 +822,11 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::__contains__.int_list(int[] l, int item) -> bool"), + listContains, + aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::__contains__.str_list(str[] l, str item) -> bool"), diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 0be346246656..8361fb3b3385 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -29,7 +29,15 @@ RegisterOperators reg( {Operator( prim::profile, [](const Node* node) -> Operation { - auto callback = node->cast()->getCallback(); + return [](Stack* stack) { + AT_ERROR( + "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT + }; + }, + aliasAnalysisSpecialCase()), + Operator( + prim::profile_ivalue, + [](const Node* node) -> Operation { return [](Stack* stack) { AT_ERROR( "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT @@ -711,10 +719,6 @@ RegisterOperators reg2({ // `listContains` is not implemented for non-primitive types // TODO: Add List[bool] once .to> doesn't throw an error - Operator( - "aten::__contains__.int_list(int[] l, int item) -> bool", - listContains, - aliasAnalysisFromSchema()), Operator( "aten::__contains__.float_list(float[] l, float item) -> bool", listContains, diff --git a/torch/csrc/jit/runtime/slice_indices_adjust.cpp b/torch/csrc/jit/runtime/slice_indices_adjust.cpp new file mode 100644 index 000000000000..e71d6ba94c9a --- /dev/null +++ b/torch/csrc/jit/runtime/slice_indices_adjust.cpp @@ -0,0 +1,56 @@ +#include +#include +#include + +namespace torch { +namespace jit { + +int64_t slice_indices_adjust( + int64_t length, + int64_t* start, + int64_t* stop, + int64_t step) { + TORCH_CHECK(step != 0, "List slice should have non-zero step") + TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds") + + // Comes from PySlice_Unpack. + if (*start == INT64_MAX) { + *start = (step < 0) ? INT64_MAX : 0; + } + if (*stop == INT64_MAX) { + *stop = (step < 0) ? INT64_MIN : INT64_MAX; + } + + // Comes from PySlice_AdjustIndices. + if (*start < 0) { + *start += length; + if (*start < 0) { + *start = (step < 0) ? -1 : 0; + } + } else if (*start >= length) { + *start = (step < 0) ? length - 1 : length; + } + + if (*stop < 0) { + *stop += length; + if (*stop < 0) { + *stop = (step < 0) ? -1 : 0; + } + } else if (*stop >= length) { + *stop = (step < 0) ? length - 1 : length; + } + + if (step < 0) { + if (*stop < *start) { + return (*start - *stop - 1) / (-step) + 1; + } + } else { + if (*start < *stop) { + return (*stop - *start - 1) / step + 1; + } + } + return 0; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/slice_indices_adjust.h b/torch/csrc/jit/runtime/slice_indices_adjust.h new file mode 100644 index 000000000000..ea1e9511769d --- /dev/null +++ b/torch/csrc/jit/runtime/slice_indices_adjust.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { + +// Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +// 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software +// Foundation; All Rights Reserved +// +// Stolen (with appropriate modifications) by @agolynski +// (https://github.com/pytorch/pytorch/pull/33019) from cpython repo +// Objects/sliceobject.c with comment: this is harder to get right than you +// might think +// +// This adjusts indexes according to python list semantics and returns number +// of elements in the resulting list. +TORCH_API int64_t slice_indices_adjust( + int64_t length, + int64_t* start, + int64_t* stop, + int64_t step); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp new file mode 100644 index 000000000000..fc8defe2dcfb --- /dev/null +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -0,0 +1,254 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +void createFusionGroups(Block* block, AliasDb* aliasDb); + +void fuseStaticSubgraphs(std::shared_ptr graph) { + PrepareGraphForStaticRuntime(graph); + auto aliasDb = torch::make_unique(graph); + createFusionGroups(graph->block(), aliasDb.get()); + torch::jit::EliminateDeadCode(graph); +} + +Operation createStaticSubgraphRuntime(const Node* node) { + auto g = torch::jit::PrepareForStaticRuntime(node->g(attr::Subgraph)); + auto runtime = std::make_shared(g); + auto num_inputs = runtime->get_inference_module()->input_regs.size(); + return [runtime, num_inputs](Stack* stack) { + RECORD_FUNCTION("Static Runtime", std::vector()); + auto inps = torch::jit::last(stack, num_inputs); + std::vector t_inputs; + t_inputs.reserve(num_inputs); + for (const auto& inp : inps) { + t_inputs.emplace_back(inp.toTensor()); + } + torch::jit::drop(stack, num_inputs); + auto outputs = runtime->run(t_inputs); + for (auto& o : outputs) { + push_one(*stack, std::move(o)); + } + return 0; + }; +} + +RegisterOperators StaticSubgraphOps({torch::jit::Operator( + prim::StaticSubgraph, + createStaticSubgraphRuntime, + AliasAnalysisKind::INTERNAL_SPECIAL_CASE)}); + +#define REQ(cond) \ + if (!(cond)) { \ + GRAPH_DEBUG("Failed cond " #cond "\n"); \ + return false; \ + } + +bool canHandle(Node* node) { + for (Value* input : node->inputs()) { + // TODO checks + } + + auto kind = node->kind(); + if (kind.is_prim()) { + REQ(kind == prim::TupleConstruct || kind == prim::ListConstruct || + kind == prim::StaticSubgraph); + return true; + } + const Operator& op = node->getOperator(); + auto analysis = op.aliasAnalysisKind(); + if (AliasAnalysisKind::PURE_FUNCTION == analysis || + (AliasAnalysisKind::FROM_SCHEMA == analysis && + !node->schema().is_mutable())) { + return true; + } + return false; +} + +bool canMerge(Node* consumer, Node* producer, AliasDb* aliasDb) { + // Only fuse within a block + REQ(consumer->owningBlock() == producer->owningBlock()); + + // Symbolic checks + REQ(canHandle(producer) || producer->kind() == prim::StaticSubgraph); + TORCH_INTERNAL_ASSERT( + consumer->kind() == prim::StaticSubgraph || canHandle(consumer)); + + // Alias checks + REQ(aliasDb->couldMoveBeforeTopologically(producer, consumer)); + + // Ops that return aliases can only be folded if this is the only use. + if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze || + producer->kind() == prim::ConstantChunk) { + for (auto& use : producer->output(0)->uses()) { + REQ(use.user == consumer); + } + } + + return true; +} + +Node* getOrCreateStaticSubgraph(Node* n, AliasDb* aliasDb) { + if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::StaticSubgraph) { + return n; + } + GRAPH_UPDATE("Creating a static subgraph::Group node from: ", *n); + return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( + n, prim::StaticSubgraph, *aliasDb); +} + +value_list sortReverseTopological(ArrayRef inputs, Block* b) { + value_list result; + for (auto i : inputs) { + if (i->node()->owningBlock() == b) { + result.push_back(i); + } + } + // Sort in reverse topological order + std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { + return a->node()->isAfter(b->node()); + }); + return result; +} + +static void debugDumpFusionGroup(const std::string& msg, Node* n) { + GRAPH_DEBUG(msg, *n); + if (n->kind() == prim::StaticSubgraph) { + GRAPH_DEBUG(*n->g(attr::Subgraph)); + } +} + +c10::optional tryMerge( + Node* fusion_group, + Node* to_merge, + AliasDb* aliasDb) { + if (!canMerge(fusion_group, to_merge, aliasDb)) { + return c10::nullopt; + } + + std::vector nodes_to_merge = {to_merge}; + + if (to_merge->kind() == aten::cat) { + Node* listconstruct = to_merge->input(0)->node(); + nodes_to_merge.push_back(listconstruct); + } + + // First, try to move all the nodes we want to fuse next to the fusion + // group. + Node* move_point = fusion_group; + for (auto n : nodes_to_merge) { + GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n)); + if (!aliasDb->moveBeforeTopologicallyValid(n, move_point)) { + GRAPH_UPDATE("Failed to move because of AliasDb checks!"); + return c10::nullopt; + } + move_point = n; + } + + // Now all the nodes that we're going to fuse are moved next to the fusion + // group, so we can safely merge them into the fusion group subgraph. + fusion_group = getOrCreateStaticSubgraph(fusion_group, aliasDb); + + for (auto n : nodes_to_merge) { + GRAPH_UPDATE("Merging ", getHeader(n)); + SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( + n, fusion_group, *aliasDb); + } + return fusion_group; +} + +std::pair createFusionGroup( + Node* fusion_node, + AliasDb* aliasDb) { + fusion_node = getOrCreateStaticSubgraph(fusion_node, aliasDb); + + GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n"); + auto inputs = + sortReverseTopological(fusion_node->inputs(), fusion_node->owningBlock()); + for (auto input : inputs) { + debugDumpFusionGroup("Current fusion group: ", fusion_node); + GRAPH_DEBUG("Trying to merge: ", *input->node()); + if (auto maybe_fusion_group = + tryMerge(fusion_node, input->node(), aliasDb)) { + // we successfully merged, so the new group's `inputs` may have + // changed. So rescan the new group for more merging opportunities. + return std::make_pair( + maybe_fusion_group.value()->reverseIterator(), true); + } + } + + return std::make_pair(++fusion_node->reverseIterator(), false); +} + +std::pair scanNode(Node* n, AliasDb* aliasDb) { + GRAPH_DEBUG("Considering node:", *n); + + if (!canHandle(n)) { + return std::make_pair(++n->reverseIterator(), false); + } + + return createFusionGroup(n, aliasDb); +} + +void createFusionGroups(Block* block, AliasDb* aliasDb) { + bool any_changed = true; + while (any_changed) { + any_changed = false; + for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) { + bool changed; + std::tie(it, changed) = scanNode(*it, aliasDb); + any_changed |= changed; + } + } + + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + createFusionGroups(b, aliasDb); + } + } + + // Try to merge adjacent fusion groups together. Because we have only merged + // by looking at graph inputs, without this we would not attempt to merge + // adjacent fusion groups that don't have a depdency on each other + + std::vector initial_fusion_groups; + for (Node* n : block->nodes()) { + if (n->kind() == prim::StaticSubgraph) { + initial_fusion_groups.push_back(n); + } + } + + Node* prev_fusion_group = + initial_fusion_groups.size() ? initial_fusion_groups[0] : nullptr; + + for (size_t i = 1; i < initial_fusion_groups.size(); ++i) { + // Try merging the just created fusion group into the previous one. + // If it did not work, then put the previous fusion group into + // fusion_groups vector - we will not touch it anymore in this loop. + // If merging suceeded, save the merged group as the "previous" fusion + // group so that we can try to merge the next one into it. + + Node* fusion_group = initial_fusion_groups[i]; + debugDumpFusionGroup( + "Trying to merge into the previous fusion group: ", prev_fusion_group); + if (auto merged_fusion_group = + tryMerge(prev_fusion_group, fusion_group, aliasDb)) { + prev_fusion_group = *merged_fusion_group; + debugDumpFusionGroup( + "Successfully merged into the previous fusion group: ", + prev_fusion_group); + } else { + GRAPH_DEBUG("Cannot merge into the previous fusion group"); + prev_fusion_group = fusion_group; + } + } +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/fusion.h b/torch/csrc/jit/runtime/static/fusion.h new file mode 100644 index 000000000000..5f0e30b8505b --- /dev/null +++ b/torch/csrc/jit/runtime/static/fusion.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +TORCH_API void fuseStaticSubgraphs(std::shared_ptr graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 07d41fb1f642..ffa2ab4f7ec4 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -1,26 +1,36 @@ #include +#include #include #include #include #include #include +#include #include #include #include #include +#include #include namespace torch { namespace jit { -namespace { -void OptimizeGraph(std::shared_ptr& graph) { +void PrepareGraphForStaticRuntime(std::shared_ptr graph) { Inline(*graph); ConstantPropagation(graph); Canonicalize(graph); ConstantPropagation(graph); RemoveTensorMutation(graph); ConstantPropagation(graph); + EliminateDeadCode(graph); +} + +namespace { +void OptimizeGraph(std::shared_ptr& graph) { + PrepareGraphForStaticRuntime(graph); + FuseInferenceOpsForSparseNN(graph); + ConstantPropagation(graph); } void CheckGraphEligibility(const std::shared_ptr& graph) { @@ -160,23 +170,30 @@ LivenessMap(const std::shared_ptr& graph) { std::unordered_set GetOptimizableValues( const std::shared_ptr& graph) { - std::unordered_set is_out_of_place; - std::unordered_set is_not_out_of_place; + std::unordered_set can_reuse; + // values used by unsupported ops (as either inputs or outputs) + // these need to be removed from "can_reuse" after analyzing all nodes + std::unordered_set cannot_reuse; for (const auto& n : graph->nodes()) { - for (const auto& container : {n->inputs(), n->outputs()}) { - for (const auto& v : container) { - if (canRunOutOfPlace(n)) { - is_out_of_place.insert(v); - } else { - is_not_out_of_place.insert(v); - } + for (const auto& v : n->inputs()) { + if (canRunOutOfPlace(n) && canReuseInputs(n)) { + can_reuse.insert(v); + } else { + cannot_reuse.insert(v); + } + } + for (const auto& v : n->outputs()) { + if (canRunOutOfPlace(n) && canReuseOutputs(n)) { + can_reuse.insert(v); + } else { + cannot_reuse.insert(v); } } } - for (auto v : is_not_out_of_place) { - is_out_of_place.erase(v); + for (auto v : cannot_reuse) { + can_reuse.erase(v); } - return is_out_of_place; + return can_reuse; } size_t AssignRegisters( @@ -408,6 +425,12 @@ std::vector StaticRuntime::run( c10::IValue StaticRuntime::run( const std::vector& args, const std::unordered_map& kwargs) { + // We assume inference workloads, so we do not need + // autograd. Enabling this is a significant win on dispatcher + // overhead because it saves a round of dispatch for at least some + // functions, such as resize_ and resize_as_. + at::AutoNonVariableTypeMode non_var_type_mode(true); + if (planner_) { planner_->allocate(); } @@ -519,6 +542,10 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( const int main_runs) { TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1); + // See comment on above use of AutoNonVariableTypeMode for + // explanation. + at::AutoNonVariableTypeMode non_var_type_mode(true); + IndividualMetrics results; results.total_time = 0.0; results.time_per_node.resize(nodes_.size(), 0); diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 2eef530e778b..21ce26bf488d 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -83,6 +83,9 @@ struct TORCH_API InferenceModule { void init(); }; +TORCH_API void PrepareGraphForStaticRuntime( + std::shared_ptr g); + inline TORCH_API std::shared_ptr PrepareForStaticRuntime( const torch::jit::Module& m, InferenceModuleOptions opts = InferenceModuleOptions()) { diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp index 3088e3bc5f36..4799b5bff974 100644 --- a/torch/csrc/jit/runtime/static/init.cpp +++ b/torch/csrc/jit/runtime/static/init.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include namespace torch { @@ -68,8 +70,23 @@ void initStaticRuntimeBindings(PyObject* module) { [](std::shared_ptr g) { return StaticRuntime(PrepareForStaticRuntime(g)); }) - .def("_jit_to_static_runtime", [](const torch::jit::Module& m) { - return StaticRuntime(PrepareForStaticRuntime(m)); + .def( + "_jit_to_static_runtime", + [](const torch::jit::Module& m) { + return StaticRuntime(PrepareForStaticRuntime(m)); + }) + .def( + "_fuse_to_static_runtime", + [](torch::jit::Module& module) { + module.eval(); + module = freeze_module(module); + + Method method = module.get_method("forward"); + auto graph = method.graph(); + fuseStaticSubgraphs(graph); + }) + .def("_fuse_to_static_runtime", [](std::shared_ptr g) { + fuseStaticSubgraphs(g); }); } diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index ab0640bf75f1..57db79699e07 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -6,12 +6,6 @@ namespace torch { namespace jit { -namespace { -inline at::Tensor create_empty_from(const at::Tensor& t) { - return at::empty({0}, t.options()); -} -} // namespace - C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor); bool canRunOutOfPlace(Node* n) { @@ -19,19 +13,30 @@ bool canRunOutOfPlace(Node* n) { return SROperatorRegistry()->Has(op_name); } +bool canReuseInputs(Node* n) { + auto op_name = std::string(n->kind().toQualString()); + DCHECK(SROperatorRegistry()->Has(op_name)); + return SROperatorRegistry()->Create(op_name)->CanReuseInput(); +} + +bool canReuseOutputs(Node* n) { + auto op_name = std::string(n->kind().toQualString()); + DCHECK(SROperatorRegistry()->Has(op_name)); + return SROperatorRegistry()->Create(op_name)->CanReuseOutput(); +} + // TODO: expand to include all view producing ops, mostly in // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp bool canRunNatively(Node* n) { // In alphabetical order const static std::unordered_set native_nodes{ "aten::flatten", + "aten::narrow", "aten::permute", "aten::reshape", "aten::slice", "aten::transpose", "aten::to", - "aten::reshape", - "aten::slice", "prim::ListConstruct", "prim::ListUnpack", "prim::TupleConstruct"}; @@ -45,6 +50,34 @@ bool canRunNatively(Node* n) { return true; } +// TODO: PLEASE DON'T COPY PASTE THIS, this is copy pasted +// generated code to unblock, need to make this nicer +struct static_add final : public at::native::structured_add_out { + static_add(at::Tensor& output) : output_(output) {} + void set_output( + int64_t output_idx, + at::IntArrayRef sizes, + at::IntArrayRef strides, + at::TensorOptions options, + at::DimnameList names) override { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx == 0); + // NB: do NOT use resize_output as it will complain if not zero sized. + at::native::resize_(output_, sizes); + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + output_.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + output_.unsafeGetTensorImpl()->empty_tensor_restride( + *options.memory_format_opt()); + } + } + const at::Tensor& maybe_get_output(int64_t output_idx) override { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx == 0); + return output_; + } + at::Tensor& output_; +}; + REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { auto in0_t = p_node->Input(0, reg).toTensor(); @@ -54,8 +87,9 @@ REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator { p_node->Output(0, reg) = create_empty_from(in0_t); } auto out_t = p_node->Output(0, reg).toTensor(); - out_t.resize_({0}); - at::native::add_out(out_t, in0_t, in1_t, in2_s); + static_add op{out_t}; + op.meta(in0_t, in1_t, in2_s); + op.impl(out_t, in0_t, in1_t, in2_s); }; }); @@ -373,6 +407,37 @@ getNativeOperation(Node* n) { p_node->Output(0, reg) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i); }; + } else if (n->kind() == c10::Symbol::fromQualString("aten::narrow")) { + return [](const ProcessedNode* p_node, std::vector& reg) { + auto self = p_node->Input(0, reg).toTensor(); // self + auto dim = p_node->Input(1, reg).toInt(); // dim + int64_t start = 0; + if (p_node->Input(2, reg).isScalar()) { + start = p_node->Input(2, reg).toInt(); + } else { + auto t = p_node->Input(2, reg).toTensor(); + start = t.item(); + } + auto length = p_node->Input(3, reg).toInt(); // length + TORCH_CHECK( + self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + auto cur_size = self.size(dim); + if (start != cur_size && start < 0) { // start being the end is valid, but + // not a valid dim specification. + start = at::maybe_wrap_dim(start, cur_size); + } + TORCH_CHECK( + length >= 0 && start <= cur_size - length, + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + p_node->Output(0, reg) = + at::native::slice(self, dim, start, start + length, 1); + }; } else if (n->kind() == c10::Symbol::fromQualString("aten::to")) { return [](const ProcessedNode* p_node, std::vector& reg) { DCHECK(p_node->input_regs().size() == 5); diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h index dabff008aa20..467dac282668 100644 --- a/torch/csrc/jit/runtime/static/ops.h +++ b/torch/csrc/jit/runtime/static/ops.h @@ -10,21 +10,48 @@ using SROperator = std::function&)>; using SROpFunctor = std::function; struct SROperatorFunctor { - virtual SROperator Generate(Node*) = 0; + virtual SROperator Generate(Node*) { + std::function&)> out; + return out; + } + virtual bool CanReuseInput() { + return false; + } + virtual bool CanReuseOutput() { + return false; + } virtual ~SROperatorFunctor() = default; }; C10_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor); -#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \ - struct SROperatorFunctor_##id : public SROperatorFunctor { \ - const SROpFunctor fn = __VA_ARGS__; \ - SROperator Generate(Node* n) override { \ - return fn(n); \ - } \ - }; \ + +// TODO: reuse_inp reuse_out can be deprecated with further analysis +// try to avoid this API. +#define REGISTER_OPERATOR_FUNCTOR_OPT(name, id, reuse_inp, reuse_out, ...) \ + struct SROperatorFunctor_##id : public SROperatorFunctor { \ + const SROpFunctor fn = __VA_ARGS__; \ + bool CanReuseInput() override { \ + return reuse_inp; \ + } \ + bool CanReuseOutput() override { \ + return reuse_out; \ + } \ + SROperator Generate(Node* n) override { \ + return fn(n); \ + } \ + }; \ C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id); +#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \ + REGISTER_OPERATOR_FUNCTOR_OPT(name, id, true, true, __VA_ARGS__) + +inline at::Tensor create_empty_from(const at::Tensor& t) { + return at::empty({0}, t.options()); +} + bool canRunOutOfPlace(Node* n); +bool canReuseInputs(Node* n); +bool canReuseOutputs(Node* n); std::function&)> getOutOfPlaceOperation(Node* n); diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp new file mode 100644 index 000000000000..a75d187b2a49 --- /dev/null +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -0,0 +1,83 @@ +#include +#include + +namespace torch { +namespace jit { + +void ConcatAddMulReplaceNaNClip(std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %y0 = aten::cat(%a, %b) + %y1 = aten::add(%y0, %c, %d) + %y2 = aten::mul(%y1, %e) + %y3 = aten::nan_to_num(%y2, %f, %g, %h) + %res = aten::clamp(%y3, %i, %j) + return (%res))IR"; + std::string pattern2 = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %y0 = aten::cat(%a, %b) + %y1 = aten::add(%y0, %c, %d) + %y2 = aten::mul(%y1, %e) + %y3 = aten::nan_to_num_(%y2, %f, %g, %h) + %res = aten::clamp(%y3, %i, %j) + return (%res))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %res = fb::concat_add_mul_replacenan_clip(%c, %e, %a, %i, %j) + return (%res))IR"; + + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); + + fuse.RegisterRewritePattern(pattern2, fused_pattern); + fuse.runOnGraph(graph); +} + +void CastedBatchOneHotLengths(std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g): + %y0 : Tensor = aten::to(%a, %b, %c, %c, %d) + %y1 : Tensor = fb::batch_one_hot_lengths(%y0, %e, %f) + %res : Tensor = aten::to(%y1, %g, %c, %c, %d) + return (%res))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g): + %res : Tensor = fb::casted_batch_one_hot_lengths(%a, %e, %f) + return (%res))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + +void ConcatBatchMatMulBatchGather(std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %y0 : Tensor = aten::stack(%a, %b) + %y1 : Tensor = aten::transpose(%y0, %b, %c) + %y2 : Tensor = aten::bmm(%y0, %y1) + %y3 : Tensor = aten::flatten(%y2, %d, %e) + %res : Tensor = aten::index_select(%y3, %b, %f) + return (%res))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %res : Tensor = fb::concat_batch_matmul_batch_gather(%f, %a) + return (%res))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + +void FuseInferenceOpsForSparseNN(std::shared_ptr& graph) { +#ifdef FBCODE_CAFFE2 + ConcatAddMulReplaceNaNClip(graph); + CastedBatchOneHotLengths(graph); + ConcatBatchMatMulBatchGather(graph); +#endif +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h new file mode 100644 index 000000000000..7cc9c52f7696 --- /dev/null +++ b/torch/csrc/jit/runtime/static/passes.h @@ -0,0 +1,9 @@ +#include + +namespace torch { +namespace jit { + +void FuseInferenceOpsForSparseNN(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index e9f9d27bf166..3291285b90dd 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index 3956d0283487..ed5dde0b08b0 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -108,7 +108,7 @@ class ScriptModuleDeserializer final { public: ScriptModuleDeserializer( std::shared_ptr cu, - std::unique_ptr reader) + std::shared_ptr reader) : compilation_unit_(std::move(cu)), reader_(std::move(reader)), source_importer_( @@ -128,7 +128,7 @@ class ScriptModuleDeserializer final { IValue readArchive(const std::string& archive_name); std::shared_ptr compilation_unit_; - std::unique_ptr reader_; + std::shared_ptr reader_; c10::optional device_; std::vector constants_table_; SourceImporter source_importer_; @@ -175,7 +175,6 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) { return obj; } }; - return readArchiveAndTensors( archive_name, type_resolver, obj_loader, device_, *reader_.get()); } @@ -257,8 +256,7 @@ Module ScriptModuleDeserializer::deserialize( } if (reader_->hasRecord("model.json")) { #if !defined(C10_MOBILE) && !defined(C10_DISABLE_LEGACY_IMPORT) - return torch::jit::LEGACY_deserialize( - compilation_unit_, std::move(reader_), device_); + return torch::jit::LEGACY_deserialize(compilation_unit_, reader_, device_); #else AT_ERROR("Legacy model format is not supported on mobile."); #endif @@ -271,7 +269,6 @@ Module ScriptModuleDeserializer::deserialize( rewriteQuantizedConvForBC(m); return m; } - } // namespace Module import_ir_module( @@ -323,7 +320,7 @@ Module load( } Module load( - std::unique_ptr rai, + std::shared_ptr rai, c10::optional device, ExtraFilesMap& extra_files) { // Verify that we're loading a zip archive and not a torch.save pickle archive @@ -347,7 +344,7 @@ Module load( " produced by `torch.jit.save()`"); } - auto reader = torch::make_unique(std::move(rai)); + auto reader = std::make_shared(std::move(rai)); auto cu = std::make_shared(); ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); diff --git a/torch/csrc/jit/serialization/import.h b/torch/csrc/jit/serialization/import.h index 543a1ca32aaf..cbfb765a6350 100644 --- a/torch/csrc/jit/serialization/import.h +++ b/torch/csrc/jit/serialization/import.h @@ -55,13 +55,13 @@ TORCH_API Module load( c10::optional device = c10::nullopt, ExtraFilesMap& extra_files = default_extra_files); -/// Loads a serialized `Module` from the given `rai`. +/// Loads a serialized `Module` from the given shared_ptr `rai`. /// /// The reader adapter, which is for customized input stream, must contain a /// serialized `Module`, exported either via `ScriptModule.save()` in /// Python or `torch::jit::ExportModule` in C++. TORCH_API Module load( - std::unique_ptr rai, + std::shared_ptr rai, c10::optional device = c10::nullopt, ExtraFilesMap& extra_files = default_extra_files); diff --git a/torch/csrc/jit/serialization/import_legacy.cpp b/torch/csrc/jit/serialization/import_legacy.cpp index 7a8279e0199c..40e035b82090 100644 --- a/torch/csrc/jit/serialization/import_legacy.cpp +++ b/torch/csrc/jit/serialization/import_legacy.cpp @@ -40,7 +40,7 @@ class ScriptModuleDeserializer final { public: ScriptModuleDeserializer( std::shared_ptr cu, - std::unique_ptr reader, + std::shared_ptr reader, const c10::optional& device) : compilation_unit_(std::move(cu)), reader_(std::move(reader)), @@ -76,7 +76,7 @@ class ScriptModuleDeserializer final { std::shared_ptr sourceLoader(const std::string& qualifier); std::shared_ptr compilation_unit_; - std::unique_ptr reader_; + std::shared_ptr reader_; c10::optional device_; // Legacy only tensor can be a constant. std::vector constant_table_; @@ -383,7 +383,7 @@ Module ScriptModuleDeserializer::LEGACY_convertModule( Module LEGACY_deserialize( std::shared_ptr cu, - std::unique_ptr reader, + std::shared_ptr reader, const c10::optional& device) { ScriptModuleDeserializer deserializer( std::move(cu), std::move(reader), device); diff --git a/torch/csrc/jit/serialization/import_legacy.h b/torch/csrc/jit/serialization/import_legacy.h index 64f8a7da1968..a26182810959 100644 --- a/torch/csrc/jit/serialization/import_legacy.h +++ b/torch/csrc/jit/serialization/import_legacy.h @@ -16,7 +16,7 @@ struct CompilationUnit; // Deserializes a model in legacy format. Module LEGACY_deserialize( std::shared_ptr cu, - std::unique_ptr reader, + std::shared_ptr reader, const c10::optional& device); } // namespace jit diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 4473b0cb50dd..6a557e6e53f3 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -209,7 +209,7 @@ class TORCH_API Pickler { // the left of a '::', its type cannot be deduced by the compiler so one must // explicitly instantiate the template, i.e. push(int) works, push(int) // does not) - static constexpr size_t kBufferSize = 256; + static CONSTEXPR_EXCEPT_WIN_CUDA size_t kBufferSize = 256; template void push(typename std::common_type::type value) { const char* begin = reinterpret_cast(&value); diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index b8f16c50e05f..7f1f09032555 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -30,11 +30,7 @@ RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList:: void RegisterCodeGenList::AddStmtFactoryMethod( const std::string& name, const StmtFactoryMethod& stmt_factory_method) { - auto insert_ret = - stmt_factory_methods_.insert(std::make_pair(name, stmt_factory_method)); - if (!insert_ret.second) { - throw std::runtime_error("Duplicated CodeGen names: " + name); - } + stmt_factory_methods_[name] = stmt_factory_method; } std::unique_ptr CreateCodeGen( diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index 6bf3456e3b85..e16a9e2c5d31 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -82,7 +82,7 @@ class CodeGen::BufferArg { BufferArg(const Placeholder& buffer) : var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {} BufferArg(Tensor* tensor) - : var_(tensor->buf()->base_handle()), dtype_(tensor->body()->dtype()) {} + : var_(tensor->buf()->base_handle()), dtype_(tensor->buf()->dtype()) {} BufferArg(const VarHandle& var) : var_(var.node()), dtype_(var.dtype()), isVar_(true) {} diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 7b8a4c194782..e7fbd376d563 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -124,6 +125,14 @@ inline c10::Half div_value(c10::Half lhs, c10::Half rhs) { return lhs / rhs; } +template +To raw_bitcast(const From& src) { + TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation"); + To storage; + std::memcpy(&storage, &src, sizeof(From)); + return reinterpret_cast(storage); +} + class SimpleIREvaluator : public CodeGen, public IRVisitor { public: template @@ -573,6 +582,57 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor { } } + template + std::vector bitcastValues(const Dtype& src_dtype, const Value& v) { + const std::vector& src_values = v.as_vec(); + std::vector dst_values(src_values.size()); + for (int i = 0; i < src_dtype.lanes(); ++i) { + dst_values[i] = raw_bitcast(src_values[i]); + } + return dst_values; + } + + template + void doBitCastFromSrc( + const Dtype& src_dtype, + const Dtype& dst_dtype, + const Value& v) { + switch (dst_dtype.scalar_type()) { +#define DST_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + this->value_ = Value(bitcastValues(src_dtype, v)); \ + break; + // bool/half not supported + AT_FORALL_SCALAR_TYPES(DST_TYPE_CASE); +#undef DST_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + + TORCH_API void visit(const BitCast* v) override { + const Expr* src_value = v->src_value(); + src_value->accept(this); + Dtype dst_dtype = v->dtype(); + Dtype src_dtype = src_value->dtype(); + if (src_dtype.byte_size() != dst_dtype.byte_size()) { + throw malformed_input("lane mismatch in Cast", v); + } + if (src_dtype != dst_dtype) { + switch (src_dtype.scalar_type()) { +#define SRC_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + doBitCastFromSrc(src_dtype, dst_dtype, value_); \ + break; + // bool/half not supported + AT_FORALL_SCALAR_TYPES(SRC_TYPE_CASE); +#undef SRC_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + } + TORCH_API void visit(const For* v) override { const Expr* var_node = v->var(); v->start()->accept(this); diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 9b8dd23db0b1..cd05333656c0 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -31,6 +31,7 @@ enum IRNodeType { kCompareSelect, kLet, kCast, + kBitCast, kBroadcast, kRamp, kPolynomial, diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 7eeea564a6a7..6fe4bf0e2ebd 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -28,6 +28,7 @@ inline int getPrecedence(IRNodeType ty) { case kPrimitive: return 0; case kCast: + case kBitCast: return 2; case kAdd: case kSub: @@ -81,6 +82,34 @@ ExprHandle cast(const ExprHandle& src_value) { return Cast::make(Dtype(ToDtype(), src_value.dtype().lanes()), src_value); } +// This is a bitwise cast, akin to bitcast in LLVM +class BitCast : public ExprNode { + public: + const Expr* src_value() const { + return src_value_; + } + static ExprHandle make(Dtype dtype, const ExprHandle& src_value) { + return ExprHandle(new BitCast(dtype, src_value.node())); + } + BitCast(Dtype dtype, const Expr* src_value) + : ExprNodeBase(dtype, kBitCast), src_value_(src_value) { + TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size()); + } + + bool isConstant() const override { + return src_value_->isConstant(); + } + + private: + const Expr* src_value_; +}; + +template +ExprHandle bitcast(const ExprHandle& src_value) { + return BitCast::make( + Dtype(ToDtype(), src_value.dtype().lanes()), src_value); +} + // Represent the expression node for binary operators. // A CRTP pattern to share common code among the operators. template diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index 5f0889842b1e..ddbe88bb2c8f 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -139,6 +139,15 @@ const Expr* IRMutator::mutate(const Cast* v) { return new Cast(v->dtype(), src_value_new); } +const Expr* IRMutator::mutate(const BitCast* v) { + const Expr* src_value = v->src_value(); + const Expr* src_value_new = src_value->accept_mutator(this); + if (src_value_new == v->src_value()) { + return v; + } + return new BitCast(v->dtype(), src_value_new); +} + const Expr* IRMutator::mutate(const Var* v) { return v; } diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h index 0913da0e972d..773920cb52fa 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.h +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -26,6 +26,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); #undef IMM_DECLARE class Cast; +class BitCast; class Var; class Buf; class Ramp; @@ -75,6 +76,7 @@ class TORCH_API IRMutator { AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); #undef IMM_MUTATE_DECLARE virtual const Expr* mutate(const Cast* v); + virtual const Expr* mutate(const BitCast* v); virtual const Expr* mutate(const Var* v); virtual const Expr* mutate(const Buf* v); virtual const Expr* mutate(const Ramp* v); diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 848bd70cf5c7..1df2f96671df 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -596,6 +596,11 @@ std::string to_string(const Tensor* t) { return "(null tensor)\n"; } std::ostringstream oss; + if (!t->body()) { + oss << "Tensor " << t->buf()->name_hint() << " = " << *t->ElementStmt() + << "\n"; + return oss.str(); + } oss << "Tensor " << t->buf()->name_hint() << "("; for (size_t i = 0; i < t->ndim(); i++) { if (i != 0) { diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index ae97a6200d8b..772a28c77add 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -79,6 +79,9 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); void IRVisitor::visit(const Cast* v) { v->src_value()->accept(this); } +void IRVisitor::visit(const BitCast* v) { + v->src_value()->accept(this); +} void IRVisitor::visit(const Var* v) {} void IRVisitor::visit(const Ramp* v) { diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h index 3f5f05229c16..8353da680edb 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.h +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -26,6 +26,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE) #undef IMM_DECLARE class Cast; +class BitCast; class Var; class Buf; class Ramp; @@ -74,6 +75,7 @@ class TORCH_API IRVisitor { #undef IMM_PRINT_VISIT virtual void visit(const Cast* v); + virtual void visit(const BitCast* v); virtual void visit(const Var* v); virtual void visit(const Buf* v); virtual void visit(const Ramp* v); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 50f285104d95..f80bfc6745fa 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1,5 +1,7 @@ #include +#include +#include #include #include #include @@ -434,17 +436,68 @@ ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) { return scalars_.at(v->unique()); } -ExprHandle promoteIntegerToFloat(const ExprHandle& e) { +ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) { auto scalarType = static_cast(e.dtype().scalar_type()); if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) { return e; } - auto defaultType = static_cast( - c10::typeMetaToScalarType(c10::get_default_dtype())); - return Cast::make(Dtype(defaultType, e.dtype().lanes()), e); + + auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype()); + + // We intend to promote Integers to floating-point types + TORCH_INTERNAL_ASSERT( + !c10::isIntegralType(defaultType, /*includeBool*/ true)); + + return Cast::make( + Dtype( + static_cast(defaultType), e.dtype().lanes()), + e); +} + +ExprHandle promoteHalfToFloat(const ExprHandle& e) { + auto scalarType = static_cast(e.dtype().scalar_type()); + auto floatType = static_cast(tensorexpr::ScalarType::Float); + if (c10::isFloatingType(scalarType) && + (c10::elementSize(scalarType) < c10::elementSize(floatType))) { + return Cast::make( + Dtype(tensorexpr::ScalarType::Float, e.dtype().lanes()), e); + } else { + return e; + } +} + +ExprHandle promoteHalfToFloatAndIntegerToDefaultType(const ExprHandle& e) { + auto scalarType = static_cast(e.dtype().scalar_type()); + if (c10::isIntegralType(scalarType, /*includeBool*/ true)) { + return promoteIntegerToDefaultType(e); + } else { + return promoteHalfToFloat(e); + } +} + +bool TensorExprKernel::checkTypes( + const ScalarType highType, + const int typeConstraints) { + if (typeConstraints == kAllTypes) { + return true; + } + + if (is_integral(highType)) { + return (typeConstraints & kIntegralTypes) != 0; + } else if (is_floating_point(highType)) { + return (typeConstraints & kFloatingPointTypes) != 0; + } else if (highType == ScalarType::Bool) { + return (typeConstraints & kBoolType) != 0; + } + + // assume JIT not supporting complex and qint yet + TORCH_INTERNAL_ASSERT((typeConstraints & (kQintTypes | kComplexTypes)) == 0); + return false; } -void TensorExprKernel::promoteInputs(std::vector& inputs) { +void TensorExprKernel::promoteInputs( + std::vector& inputs, + const int typeConstraints) { if (inputs.empty()) { return; } @@ -455,6 +508,10 @@ void TensorExprKernel::promoteInputs(std::vector& inputs) { highType = promoteTypes(highType, input.dtype().scalar_type()); } + if (!checkTypes(highType, typeConstraints)) { + throw unsupported_dtype(); + } + for (ExprHandle& e : inputs) { e = promoteToDtype(e, highType); } @@ -561,19 +618,20 @@ std::vector TensorExprKernel::valueShape( Tensor* TensorExprKernel::computeOneOperand( const std::string& name, const torch::jit::Value* v, - const std::function& innerExpr) { + const std::function& innerExpr, + const int checkParamTypes) { auto const& n = v->node(); auto const& shape = valueShape(n->inputs()[0]); return Compute( name, c10::fmap(shape), - [this, v, innerExpr](const std::vector& axes) { + [this, v, innerExpr, checkParamTypes]( + const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices)}; - - promoteInputs(inputs); + promoteInputs(inputs, checkParamTypes); ExprHandle compute = innerExpr(inputs[0]); return demoteOutput(compute, n->output()); }); @@ -668,7 +726,8 @@ Tensor* TensorExprKernel::computeThreeOperand( const torch::jit::Value* v, const std::function< ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& - innerExpr) { + innerExpr, + bool promote_inputs) { auto const& n = v->node(); std::vector> shapes; for (size_t idx = 0; idx < 3; idx++) { @@ -679,7 +738,7 @@ Tensor* TensorExprKernel::computeThreeOperand( return Compute( name, c10::fmap(shape), - [this, v, innerExpr](const std::vector& axes) { + [this, v, innerExpr, promote_inputs](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { @@ -688,7 +747,9 @@ Tensor* TensorExprKernel::computeThreeOperand( tensorOrConstant(n->inputs()[2], indices), }; - promoteInputs(inputs); + if (promote_inputs) { + promoteInputs(inputs); + } ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]); return demoteOutput(compute, n->output()); }); @@ -789,7 +850,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::div: { return computeTwoOperand( "aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - return promoteIntegerToFloat(lhs) / promoteIntegerToFloat(rhs); + return promoteIntegerToDefaultType(lhs) / + promoteIntegerToDefaultType(rhs); }); } break; @@ -917,26 +979,31 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { const ExprHandle& in, const ExprHandle& min, const ExprHandle& max) { + auto cast = [&](const ExprHandle& e) { + return Cast::make(in.dtype(), e); + }; + if (noMin && noMax) { return in; } else if (noMin) { - return CompareSelect::make(in, max, max, in, kGT); + auto cmax = cast(max); + return CompareSelect::make(in, cmax, cmax, in, kGT); } else if (noMax) { - return CompareSelect::make(in, min, min, in, kLT); + auto cmin = cast(min); + return CompareSelect::make(in, cmin, cmin, in, kLT); } else { - return CompareSelect::make( - in, - min, - min, - CompareSelect::make(in, max, max, in, kGT), - kLT); + auto cmax = cast(max); + auto cmin = cast(min); + auto mm = CompareSelect::make(in, cmin, cmin, in, kLT); + return CompareSelect::make(mm, cmax, cmax, mm, kGT); } - }); + }, + false /* promote_inputs */); } break; case aten::sigmoid: { return computeOneOperand("aten_sigmoid", v, [](const ExprHandle& a) { - return sigmoid(promoteIntegerToFloat(a)); + return sigmoid(promoteIntegerToDefaultType(a)); }); } break; @@ -961,25 +1028,25 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::log: { return computeOneOperand("aten_log", v, [](const ExprHandle& a) { - return log(promoteIntegerToFloat(a)); + return log(promoteIntegerToDefaultType(a)); }); } break; case aten::log10: { return computeOneOperand("aten_log10", v, [](const ExprHandle& a) { - return log10(promoteIntegerToFloat(a)); + return log10(promoteIntegerToDefaultType(a)); }); } break; case aten::log1p: { return computeOneOperand("aten_log1p", v, [](const ExprHandle& a) { - return log1p(promoteIntegerToFloat(a)); + return log1p(promoteIntegerToDefaultType(a)); }); } break; case aten::log2: { return computeOneOperand("aten_log2", v, [](const ExprHandle& a) { - return log2(promoteIntegerToFloat(a)); + return log2(promoteIntegerToDefaultType(a)); }); } break; @@ -989,37 +1056,38 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::expm1: { - return computeOneOperand( - "aten_expm1", v, [](const ExprHandle& a) { return expm1(a); }); + return computeOneOperand("aten_expm1", v, [](const ExprHandle& a) { + return expm1(promoteIntegerToDefaultType(a)); + }); } break; case aten::erf: { return computeOneOperand("aten_erf", v, [](const ExprHandle& a) { - return erf(promoteIntegerToFloat(a)); + return erf(promoteIntegerToDefaultType(a)); }); } break; case aten::erfc: { return computeOneOperand("aten_erfc", v, [](const ExprHandle& a) { - return erfc(promoteIntegerToFloat(a)); + return erfc(promoteIntegerToDefaultType(a)); }); } break; case aten::cos: { return computeOneOperand("aten_cos", v, [](const ExprHandle& a) { - return cos(promoteIntegerToFloat(a)); + return cos(promoteIntegerToDefaultType(a)); }); } break; case aten::sin: { return computeOneOperand("aten_sin", v, [](const ExprHandle& a) { - return sin(promoteIntegerToFloat(a)); + return sin(promoteIntegerToDefaultType(a)); }); } break; case aten::tan: { return computeOneOperand("aten_tan", v, [](const ExprHandle& a) { - return tan(promoteIntegerToFloat(a)); + return tan(promoteIntegerToDefaultType(a)); }); } break; @@ -1132,30 +1200,31 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::acos: { return computeOneOperand("aten_acos", v, [](const ExprHandle& a) { - return acos(promoteIntegerToFloat(a)); + return acos(promoteIntegerToDefaultType(a)); }); } break; case aten::asin: { return computeOneOperand("aten_asin", v, [](const ExprHandle& a) { - return asin(promoteIntegerToFloat(a)); + return asin(promoteIntegerToDefaultType(a)); }); } break; case aten::cosh: { - return computeOneOperand( - "aten_cosh", v, [](const ExprHandle& a) { return cosh(a); }); + return computeOneOperand("aten_cosh", v, [](const ExprHandle& a) { + return cosh(promoteIntegerToDefaultType(a)); + }); } break; case aten::sinh: { return computeOneOperand("aten_sinh", v, [](const ExprHandle& a) { - return sinh(promoteIntegerToFloat(a)); + return sinh(promoteIntegerToDefaultType(a)); }); } break; case aten::atan: { return computeOneOperand("aten_atan", v, [](const ExprHandle& a) { - return atan(promoteIntegerToFloat(a)); + return atan(promoteIntegerToDefaultType(a)); }); } break; @@ -1163,19 +1232,20 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { return computeTwoOperand( "aten_atan2", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return atan2( - promoteIntegerToFloat(lhs), promoteIntegerToFloat(rhs)); + promoteIntegerToDefaultType(lhs), + promoteIntegerToDefaultType(rhs)); }); } break; case aten::tanh: { return computeOneOperand("aten_tanh", v, [](const ExprHandle& a) { - return tanh(promoteIntegerToFloat(a)); + return tanh(promoteIntegerToDefaultType(a)); }); } break; case aten::sqrt: { return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) { - return sqrt(promoteIntegerToFloat(a)); + return tensorexpr::sqrt(promoteIntegerToDefaultType(a)); }); } break; @@ -1186,7 +1256,12 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::abs: { return computeOneOperand( - "aten_abs", v, [](const ExprHandle& a) { return fabs(a); }); + "aten_abs", + v, + [](const ExprHandle& a) { + return fabs(promoteHalfToFloatAndIntegerToDefaultType(a)); + }, + kIntegralTypes | kFloatingPointTypes | kBoolType); } break; case aten::ceil: { @@ -1231,7 +1306,13 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::frac: { return computeOneOperand( - "aten_frac", v, [](const ExprHandle& a) { return a - floor(a); }); + "aten_frac", + v, + [](const ExprHandle& a) { + auto aa = promoteHalfToFloat(a); + return aa - floor(aa); + }, + kFloatingPointTypes); } break; case aten::lgamma: { @@ -1874,6 +1955,88 @@ std::vector TensorExprKernel::getReductionAxes( return axes; } +template +std::vector reverse_sort_indices(const std::vector& v) { + // initialize original index locations + std::vector idx(v.size()); + iota(idx.begin(), idx.end(), 0); + + std::sort(idx.begin(), idx.end(), [&v](size_t i1, size_t i2) { + return v[i1] > v[i2]; + }); + return idx; +} + +bool denseAndNonOverlapping( + at::ArrayRef sizes, + at::ArrayRef strides) { + return (strides == at::infer_dense_strides(sizes, strides)); +} + +Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { + const TensorTypePtr& tt = v->type()->expect(); + TORCH_INTERNAL_ASSERT(tensors_.count(v->unique())); + Tensor* tensor = tensors_[v->unique()]; + + TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes()); + const auto sizes = *tt->sizes().concrete_sizes(); + std::vector default_strides = TensorType::contiguousStridesOf(sizes); + TORCH_INTERNAL_ASSERT(tt->strides().concrete_sizes()); + const std::vector strides = *tt->strides().concrete_sizes(); + // All Tensors in NNC are layed out in default, contiguous layout. + // If the output is also default contiguous we don't need to do anything + if (strides == default_strides) { + return tensor; + } + // If the tensor is not dense or overlaps, we have + // no way of matching the profiled striding + if (!denseAndNonOverlapping(sizes, strides)) { + return tensor; + } + + auto dims = dimsFromSizes(sizesForValue(v)); + // We need to convert the output tensor so that its values are layed + // so that whene viewed from the output strides the values are correct. + // A contiguous Tensor of size(2, 3) with values 0-5 is layed out as: + // [0] [1] [2] [3] [4] [5] + // The same valued tensor with strides (2, 1) would be layed out like + // [0] [3] [1] [4] [2] [5] + // When we are doing the re-ordering of values into the output tensor, + // we are iterating per-element of the input, ad we are fixed + // in indexing in to the output tensor at [i, j] = val + // `val` we want here is equal to the indices for the output + // tensor that would have given the same position as the output + // The position is equal to the sum of stride[i] * index[i], + // and we can can calculate the equivalent indices in the + // output tensor strides by iteratively computing the index of + // the biggest stride: + // absolute = ... + // for stride in strides_from_largest_to_smallest: + // cur_idx = absolute // stride + // absolute = absolute % stride + + return Compute( + "output_1", dims, [&](const std::vector& axes_input) { + std::vector axes(axes_input.begin(), axes_input.end()); + auto absolute_position = IntImm::make(0); + for (size_t i = 0; i < axes.size(); ++i) { + absolute_position = + absolute_position + (IntImm::make(default_strides[i]) * axes[i]); + } + std::vector sorted_stride_indices = + reverse_sort_indices(strides); + std::vector new_axes(sorted_stride_indices.size()); + for (size_t stride_index : sorted_stride_indices) { + auto stride = strides[stride_index]; + auto index = Div::make(absolute_position, IntImm::make(stride)); + absolute_position = + Mod::make(absolute_position, IntImm::make(stride)); + new_axes[stride_index] = index; + } + return tensor->call(new_axes); + }); +} + void TensorExprKernel::compile() { KernelScope kernelScope(&kernelArena_); GRAPH_DUMP("TensorExprKernel graph:", graph_); @@ -1901,16 +2064,38 @@ void TensorExprKernel::compile() { } } + device_ = *pickDeviceType(graph_->inputs()); + // Move output operands from `tensors_` to `tensorOutputs_` for (const auto& output : graph_->outputs()) { if (!tensors_.count(output->unique())) { throw malformed_input("cannot find output Tensor"); } + // The "strided" tensor will be incorrect if used in NNC, + // since NNC views it as contiguous. Only convert it to the right + // strides at the end of the kernel (if already contiguous it's a no-op) + Tensor* properly_strided_output = convertOutputToCorrectStrides(output); + tensors_[output->unique()] = properly_strided_output; + const auto& tt = output->type()->expect(); + auto sizes = *tt->sizes().concrete_sizes(); + tensorOutputSizes_.push_back(sizes); + auto strides = *tt->strides().concrete_sizes(); + + // If the tensor is not dense or overlaps, we have + // no way of matching the profiled striding + if (denseAndNonOverlapping(sizes, strides)) { + tensorOutputStrides_.push_back(*tt->strides().concrete_sizes()); + } else { + tensorOutputStrides_.push_back(TensorType::contiguousStridesOf(sizes)); + } + tensorOutputs_.emplace_back(tensors_.at(output->unique())); + tensorOutputTensorOptions_.push_back( + c10::TensorOptions(tensorType(tensors_[output->unique()])) + .device(device_)); tensors_.erase(output->unique()); } - device_ = *pickDeviceType(graph_->inputs()); BackendType backendType = inferBackendTypeFromDevice(device_); Stmt* stmt = generateStmt(backendType); // Set up formal params (inputs, then outputs) for kernel. @@ -1927,36 +2112,34 @@ void TensorExprKernel::compile() { TensorExprKernel::TensorExprKernel(const std::shared_ptr& subgraph) : graph_(subgraph), code_(subgraph, "") { - if (!fallbackAllowed()) { + allow_fallback_ = fallbackAllowed(); + if (!allow_fallback_) { compile(); return; } + use_fallback_ = fallbackEnforced(); + if (use_fallback_) { + return; + } + try { compile(); } catch (...) { - fallback_ = true; + use_fallback_ = true; } } void TensorExprKernel::run(Stack& stack) { - if (fallbackEnforced()) { - fallback(stack); - return; - } - if (!fallbackAllowed()) { + if (!use_fallback_ && !allow_fallback_) { runKernel(stack); - return; - } - - if (fallback_) { - fallback(stack); - return; - } - try { - runKernel(stack); - } catch (...) { - fallback_ = true; + } else if (!use_fallback_ && allow_fallback_) { + try { + runKernel(stack); + } catch (...) { + fallback(stack); + } + } else { fallback(stack); } } @@ -1964,47 +2147,26 @@ void TensorExprKernel::run(Stack& stack) { std::vector TensorExprKernel::prepareRunArgs( const at::ArrayRef& inputs, std::vector& outputs) { - std::map varToSize; - std::vector runArgs; - for (size_t i = 0; i < inputs.size(); i++) { + runArgs.reserve(inputs.size() + tensorOutputs_.size()); + + for (size_t i = 0, e = inputs.size(); i < e; i++) { auto const& input = inputs[i]; if (input.isInt()) { runArgs.emplace_back((int32_t)input.toInt()); } else if (input.isDouble()) { runArgs.emplace_back((float)input.toDouble()); } else if (input.isTensor()) { - auto const& tensor = input.toTensor(); - runArgs.emplace_back(tensor.data_ptr()); - for (auto const& size : kernelArgs_[i].sizes()) { - int32_t s = tensor.sizes()[size.idx]; - runArgs.emplace_back(s); - varToSize[size.var.node()] = s; - } - for (auto const& stride : kernelArgs_[i].strides()) { - int32_t s = tensor.strides()[stride.idx]; - runArgs.emplace_back(s); - } + runArgs.emplace_back(input.toTensor().data_ptr()); } } - for (auto& o : tensorOutputs_) { - std::vector tensorSize; - for (const Expr* dim : o->dims()) { - auto it = varToSize.find(dim); - if (it != varToSize.end()) { - tensorSize.push_back(it->second); - } else { - const IntImm* s = dynamic_cast(dim); - if (!s) { - throw malformed_input("output expected Int", dim); - } - tensorSize.push_back(s->value()); - } - } - - outputs.push_back(at::empty( - tensorSize, c10::TensorOptions(tensorType(o)).device(device_))); + for (size_t i = 0, e = tensorOutputs_.size(); i < e; ++i) { + auto t = at::empty_strided( + tensorOutputSizes_[i], + tensorOutputStrides_[i], + tensorOutputTensorOptions_[i]); + outputs.push_back(t); runArgs.emplace_back(outputs.back().data_ptr()); } return runArgs; diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index 54a876eb85c6..c969669e63d3 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -36,6 +36,16 @@ class TORCH_API TensorExprKernel { } private: + enum ElementType { + kAllTypes = 0, + kIntegralTypes = 1 << 0, + kFloatingPointTypes = 1 << 1, + kBoolType = 1 << 2, + kComplexTypes = 1 << 3, + kQintTypes = 1 << 4, + kNonComplexOrQintTypes = kIntegralTypes | kBoolType | kFloatingPointTypes, + }; + enum BackendType { kUninitialized, kSimpleIREval, @@ -71,7 +81,11 @@ class TORCH_API TensorExprKernel { std::vector valueShape(const torch::jit::Value* v); - void promoteInputs(std::vector& inputs); + bool checkTypes(const ScalarType highType, const int typeConstraints); + + void promoteInputs( + std::vector& inputs, + int typeConstraints = kAllTypes); ExprHandle demoteOutput(const ExprHandle& e, const torch::jit::Value* v); @@ -82,7 +96,8 @@ class TORCH_API TensorExprKernel { Tensor* computeOneOperand( const std::string& name, const torch::jit::Value* v, - const std::function& innerExpr); + const std::function& innerExpr, + const int checkParamTypes = kAllTypes); Tensor* computeTwoOperand( const std::string& name, @@ -101,7 +116,8 @@ class TORCH_API TensorExprKernel { const torch::jit::Value* v, const std::function< ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& - innerExpr); + innerExpr, + bool promote_inputs = true); Tensor* computeConditionWithTwoOperand( const std::string& name, @@ -137,6 +153,8 @@ class TORCH_API TensorExprKernel { void bindInput(const torch::jit::Value* input); + Tensor* convertOutputToCorrectStrides(torch::jit::Value* v); + // Captures the information for reduction operation nodes. struct ReductionInfo { std::vector reductionDims; @@ -189,6 +207,9 @@ class TORCH_API TensorExprKernel { int64_t nInputs_ = 0; std::vector kernelArgs_; + std::vector> tensorOutputSizes_; + std::vector> tensorOutputStrides_; + std::vector tensorOutputTensorOptions_; std::vector tensorOutputs_; std::unordered_map tensors_; std::unordered_map scalars_; @@ -198,7 +219,8 @@ class TORCH_API TensorExprKernel { std::vector inputTypes_; std::shared_ptr graph_; Code code_; - bool fallback_{false}; + bool allow_fallback_{false}; + bool use_fallback_{false}; bool hasRandom_{false}; bool hasBroadcast_{false}; std::unordered_map> diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 35929a61266f..f4fd647be3af 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -30,10 +30,16 @@ #include #include +#include #define DEBUG_PRINT 0 using namespace torch::jit::tensorexpr; +C10_DEFINE_bool( + torch_jit_llvm_use_fast_intrinsics, + false, + "Use fast (but slightly less accurate) implementations of tanh and sigmoid"); + DEFINE_TRIGGER(llvm_codegen_created); DEFINE_TRIGGER(llvm_codegen_executed); @@ -42,18 +48,6 @@ namespace jit { namespace tensorexpr { namespace { -bool is_unsigned_integral(const ScalarType& type) { - switch (type) { - case ScalarType::Bool: - case ScalarType::Byte: - return true; - default: - return false; - } - - return false; -} - llvm::CmpInst::Predicate llvm_comparison_predicate( CompareSelectOperation compare_op, const ScalarType& type) { @@ -63,17 +57,17 @@ llvm::CmpInst::Predicate llvm_comparison_predicate( case CompareSelectOperation::kNE: return llvm::ICmpInst::ICMP_NE; case CompareSelectOperation::kGT: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_UGT - : llvm::ICmpInst::ICMP_SGT; + return is_signed(type) ? llvm::ICmpInst::ICMP_SGT + : llvm::ICmpInst::ICMP_UGT; case CompareSelectOperation::kGE: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_UGE - : llvm::ICmpInst::ICMP_SGE; + return is_signed(type) ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE; case CompareSelectOperation::kLT: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_ULT - : llvm::ICmpInst::ICMP_SLT; + return is_signed(type) ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT; case CompareSelectOperation::kLE: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_ULE - : llvm::ICmpInst::ICMP_SLE; + return is_signed(type) ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE; default: // TODO: change to a proper error report throw std::runtime_error("invalid operator type"); @@ -158,6 +152,7 @@ class LLVMCodeGenImpl : public IRVisitor { #undef IMM_VISIT_DECLARE void visit(const Cast* v) override; + void visit(const BitCast* v) override; void visit(const Var* v) override; void visit(const Ramp* v) override; void visit(const Load* v) override; @@ -495,12 +490,13 @@ void LLVMCodeGenImpl::emitKernel( irb_.SetInsertPoint(bb_); // Maybe expand some of the intrinsics. -#ifdef USE_FAST_CPU_INTRINSICS - LLVMIntrinsicsExpander intrinsics_expander; -#else - GenericIntrinsicsExpander intrinsics_expander; -#endif - stmt = stmt->accept_mutator(&intrinsics_expander); + if (FLAGS_torch_jit_llvm_use_fast_intrinsics) { + LLVMIntrinsicsExpander intrinsics_expander; + stmt = stmt->accept_mutator(&intrinsics_expander); + } else { + GenericIntrinsicsExpander intrinsics_expander; + stmt = stmt->accept_mutator(&intrinsics_expander); + } // Compile the kernel. stmt->accept(this); @@ -518,6 +514,13 @@ void LLVMCodeGenImpl::emitKernel( if (llvm::verifyFunction(*fn_, &llvm::outs())) { throw std::runtime_error("Function verification failed"); } + + // print graph debug info. + std::string fnstr; + llvm::raw_string_ostream FS(fnstr); + fn_->print(FS); + GRAPH_DEBUG("LLVM Function:\n", FS.str(), "\n"); + optimize(*module_); #if DEBUG_PRINT @@ -706,7 +709,8 @@ void LLVMCodeGenImpl::visit(const Max* v) { auto rhs = this->value_; if (v->dtype().is_integral()) { - auto icmp = irb_.CreateICmpSGT(lhs, rhs); + auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSGT(lhs, rhs) + : irb_.CreateICmpUGT(lhs, rhs); value_ = irb_.CreateSelect(icmp, lhs, rhs); return; } @@ -727,7 +731,8 @@ void LLVMCodeGenImpl::visit(const Min* v) { v->rhs()->accept(this); auto rhs = this->value_; if (v->dtype().is_integral()) { - auto icmp = irb_.CreateICmpSLT(lhs, rhs); + auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSLT(lhs, rhs) + : irb_.CreateICmpULT(lhs, rhs); value_ = irb_.CreateSelect(icmp, lhs, rhs); return; } @@ -874,6 +879,25 @@ void LLVMCodeGenImpl::visit(const Cast* v) { } } +void LLVMCodeGenImpl::visit(const BitCast* v) { + v->src_value()->accept(this); + + llvm::Type* dstType = dtypeToLLVM(v->dtype()); + if (v->dtype().lanes() > 1) { + dstType = llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes())); + } + llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype()); + + if (srcType == dstType) { + // do nothing. + return; + } + + TORCH_CHECK(llvm::CastInst::isBitCastable( + srcType->getScalarType(), dstType->getScalarType())); + value_ = irb_.CreateBitOrPointerCast(value_, dstType); +} + void LLVMCodeGenImpl::visit(const Var* v) { if (varToArg_.count(v)) { auto idx = varToArg_.at(v); @@ -1637,7 +1661,7 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { } else if (v->dtype().is_integral() && v->op_type() == kFabs) { // abs is only intrinsic defined for integer inputs in pytorch eager v->params().front()->accept(this); - if (is_unsigned_integral(v->dtype().scalar_type())) { + if (!v->dtype().is_signed()) { return; } // TODO: use llvm.abs intrinsic for LLVM 12 diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 16eb1ec11299..1598a92ac68c 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -154,6 +154,14 @@ class Vectorizer : public IRMutator { }); } + const Expr* mutate(const BitCast* v) override { + std::vector inputs = {v->src_value()}; + return try_vectorize(v, inputs, [&]() { + return BitCast::make( + Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); + }); + } + const Expr* mutate(const Cast* v) override { std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { @@ -398,7 +406,11 @@ class DepTracker : public IRVisitor { public: std::vector findUsedTensors(Tensor* tensor) { used_tensors.clear(); - tensor->body()->accept(this); + if (tensor->body()) { + tensor->body()->accept(this); + } else { + tensor->ElementStmt()->accept(this); + } return used_tensors; } @@ -505,6 +517,11 @@ LoopNest::LoopNest(const std::vector& output_tensors) { Stmt* LoopNest::lowerToStmt(Tensor* t) { Stmt* body = t->ElementStmt(); + // If this Tensor has no functional body, it already has its axes expanded. + if (nullptr == t->body()) { + return body; + } + if (t->ndim() == 0 && t->reduce_ndim() == 0) { return body; } @@ -542,8 +559,10 @@ Stmt* LoopNest::lowerToStmt(Tensor* t) { class FunctionInliner : public IRMutator { public: - FunctionInliner(Store* producer) - : buf_(producer->buf()), producer_(producer) { + FunctionInliner(Store* producer, std::unordered_set outputs) + : buf_(producer->buf()), + producer_(producer), + outputs_(std::move(outputs)) { for (auto* i : producer->indices()) { const Var* index_var = dynamic_cast(i); if (index_var == nullptr) { @@ -631,7 +650,9 @@ class FunctionInliner : public IRMutator { // Remove the buffer write from the inlined function. Stmt* mutate(const Store* v) override { - if (v == producer_) { + // If the buf_ is in the outputs set, keep its statement intact. Otherwise, + // remove it. + if (v == producer_ && !outputs_.count(buf_)) { in_producer_ = true; producer_ = dynamic_cast(IRMutator::mutate(v)); TORCH_INTERNAL_ASSERT(producer_ != nullptr); @@ -696,6 +717,7 @@ class FunctionInliner : public IRMutator { // In the producer's scope - we need to bind any calls to rand(). bool in_producer_ = false; std::unordered_map> random_bindings_; + std::unordered_set outputs_; }; bool LoopNest::computeInline(Stmt* s) { @@ -707,11 +729,6 @@ bool LoopNest::computeInline(Stmt* s) { } bool LoopNest::computeInline(const Buf* b) { - if (output_bufs_.count(b)) { - // Cannot inline producers of output Tensors - return false; - } - // Find producers. Store* relevant_store{nullptr}; auto stores = NodeFinder::find(root_stmt_); @@ -731,7 +748,7 @@ bool LoopNest::computeInline(const Buf* b) { } TORCH_INTERNAL_ASSERT(relevant_store); - FunctionInliner inliner(relevant_store); + FunctionInliner inliner(relevant_store, output_bufs_); root_stmt_ = root_stmt_->accept_mutator(&inliner); // No longer computing this intermediate tensor, so don't alloc it. @@ -745,6 +762,7 @@ void LoopNest::inlineIntermediateBufs() { // erased from the set 'intermediate_bufs_' in that function. std::unordered_set bufs_to_inline( intermediate_bufs_.begin(), intermediate_bufs_.end()); + bufs_to_inline.insert(output_bufs_.begin(), output_bufs_.end()); for (auto b : bufs_to_inline) { computeInline(b); } @@ -867,6 +885,63 @@ Stmt* LoopNest::insertAllocFree(Stmt* stmt) { return b; } +class StmtDeleter : public IRMutator { + public: + StmtDeleter(const std::unordered_set& targets) + : targets_(targets) {} + + private: + Stmt* mutate(const Block* v) override { + std::vector stmts; + + for (auto* s : v->stmts()) { + if (targets_.count(s) == 0) { + Stmt* ns = s->accept_mutator(this); + if (ns) { + stmts.push_back(Stmt::clone(ns)); + } + } + } + + return Block::make(stmts); + } + + const std::unordered_set& targets_; +}; + +void LoopNest::eliminateDeadStores() { + using namespace analysis; + MemDependencyChecker checker(getInputBufs(), getOutputBufs()); + root_stmt_->accept(&checker); + + std::unordered_set deadStores; + std::vector> outputAccesses; + for (auto* o : getOutputBufs()) { + outputAccesses.push_back(checker.output(o)); + } + + for (auto& info : checker.getHistory()) { + if (!info->isWrite()) { + continue; + } + bool found = false; + + for (auto& output : outputAccesses) { + if (checker.dependsIndirectly(output, info)) { + found = true; + break; + } + } + + if (!found) { + deadStores.insert(info->stmt()); + } + } + + StmtDeleter deleter(deadStores); + root_stmt_ = root_stmt_->accept_mutator(&deleter); +} + void LoopNest::prepareForCodegen() { // Expand reduction ops. ReductionExpander reduceExpander; diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index af0f28884f5a..540b7fa889a9 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -107,6 +107,7 @@ class TORCH_API LoopNest { For* f, const std::unordered_map& map); + void eliminateDeadStores(); void prepareForCodegen(); // Find the inner-most loops and vectorize them. Currently, this only works diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index 4afc1ffeefb5..d12f6999c8d5 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -86,7 +86,7 @@ Tensor* Compute( return new Tensor(func_name, dims, args_nodes, body); } -Stmt* Tensor::ElementStmt() { +Stmt* Tensor::ElementStmt() const { std::vector indices; for (size_t i = 0; i < buf_->ndim(); i++) { indices.push_back(args_[i]); diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index d37f14c3a606..e5e399db348b 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -12,7 +12,7 @@ namespace torch { namespace jit { namespace tensorexpr { -class Tensor : KernelScopedObject { +class TORCH_API Tensor : KernelScopedObject { public: Tensor( const std::string& name, @@ -27,7 +27,7 @@ class Tensor : KernelScopedObject { : buf_(buf), args_(args), body_(body) {} Tensor( - Buf* buf, + const Buf* buf, const std::vector& args, const std::vector& reduce_dims, const std::vector& reduce_args, @@ -38,6 +38,8 @@ class Tensor : KernelScopedObject { reduce_dims_(reduce_dims), reduce_args_(reduce_args) {} + virtual ~Tensor() {} + // Wrappers over accessors to fields of the underlying function const Expr* body() const { return body_; @@ -94,7 +96,7 @@ class Tensor : KernelScopedObject { const Expr* initializer() const { return initializer_; } - Stmt* ElementStmt(); + virtual Stmt* ElementStmt() const; template inline ExprHandle operator()(const Ts&... ts); @@ -113,6 +115,24 @@ class Tensor : KernelScopedObject { const Expr* initializer_{nullptr}; }; +class TORCH_API CompoundTensor : public Tensor { + public: + CompoundTensor( + const Buf* buf, + const std::vector& args, + Stmt* stmt) + : Tensor(buf, args, {}, {}, nullptr), stmt_(stmt) {} + + virtual ~CompoundTensor() {} + + Stmt* ElementStmt() const override { + return stmt_; + } + + private: + Stmt* stmt_; +}; + class Placeholder { public: Placeholder(const BufHandle& data) : data_(data.node()) { @@ -306,7 +326,7 @@ class FunctionCall : public CallNode { } FunctionCall(Tensor* tensor, const std::vector& params) - : BaseClass(tensor->body()->dtype(), kFunctionCall, params), + : BaseClass(tensor->buf()->dtype(), kFunctionCall, params), tensor_(tensor) {} private: diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index f7aa96be4c45..ae9bdcf1986c 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -9,33 +9,26 @@ namespace torch { namespace jit { namespace tensorexpr { -bool is_integral(const ScalarType& type) { - switch (type) { - case ScalarType::Bool: - case ScalarType::Byte: - case ScalarType::Char: - case ScalarType::Short: - case ScalarType::Int: - case ScalarType::Long: - return true; - default: - return false; - } +static bool is_c10_type(const ScalarType& type) { + return type < ScalarType::Undefined; +} - return false; +bool is_integral(const ScalarType& type) { + return is_c10_type(type) + ? c10::isIntegralType(static_cast(type), true) + : false; } bool is_floating_point(const ScalarType& type) { - switch (type) { - case ScalarType::Half: - case ScalarType::Float: - case ScalarType::Double: - return true; - default: - return false; - } + return is_c10_type(type) + ? c10::isFloatingType(static_cast(type)) + : false; +} - return false; +bool is_signed(const ScalarType& type) { + return is_c10_type(type) + ? c10::isSignedType(static_cast(type)) + : false; } Dtype Dtype::scalar_dtype() const { diff --git a/torch/csrc/jit/tensorexpr/types.h b/torch/csrc/jit/tensorexpr/types.h index 3e8ec36ec2f3..29ccf06ef035 100644 --- a/torch/csrc/jit/tensorexpr/types.h +++ b/torch/csrc/jit/tensorexpr/types.h @@ -37,6 +37,7 @@ TORCH_API std::ostream& operator<<( TORCH_API bool is_integral(const ScalarType& type); TORCH_API bool is_floating_point(const ScalarType& type); +TORCH_API bool is_signed(const ScalarType& type); // Data types for scalar and vector elements. class TORCH_API Dtype { @@ -75,6 +76,9 @@ class TORCH_API Dtype { bool is_floating_point() const { return tensorexpr::is_floating_point(scalar_type_); } + bool is_signed() const { + return tensorexpr::is_signed(scalar_type_); + } Dtype cloneWithScalarType(ScalarType nt) const { return Dtype(nt, lanes_); diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 950e7d9fb82d..c7fdf844945e 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -39,6 +39,7 @@ static std::unordered_map type_map = { {"std::string", ParameterType::STRING}, {"Dimname", ParameterType::DIMNAME}, {"DimnameList", ParameterType::DIMNAME_LIST}, + {"ScalarList", ParameterType::SCALAR_LIST}, }; // Default arg name translations for compatibility with NumPy. @@ -348,13 +349,28 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* ove return false; } -bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error) { +bool is_scalar_list(PyObject* obj) { auto tuple = six::isTuple(obj); if (!(tuple || PyList_Check(obj))) { return false; } auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); for (size_t idx = 0; idx < size; idx++) { + PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); + if (!THPUtils_checkScalar(iobj)) { + return false; + } + } + return true; +} + +bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error) { + auto tuple = six::isTuple(obj); + if (!(tuple || PyList_Check(obj))) { + return false; + } + auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); + for (long idx = 0; idx < size; idx++) { PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) { if (throw_error) { @@ -453,6 +469,9 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded return THPStream_Check(obj); case ParameterType::STRING: return THPUtils_checkString(obj); default: throw std::runtime_error("unknown parameter type"); + case ParameterType::SCALAR_LIST: { + return is_scalar_list(obj); + } } } @@ -478,6 +497,7 @@ std::string FunctionParameter::type_name() const { case ParameterType::STRING: return "str"; case ParameterType::DIMNAME: return "name"; case ParameterType::DIMNAME_LIST: return "tuple of names"; + case ParameterType::SCALAR_LIST: return "tuple of Scalars"; default: throw std::runtime_error("unknown parameter type"); } } @@ -1055,24 +1075,28 @@ at::Scalar PythonArgs::scalar_slow(int i) { signature.params[i].name, idx, var, jit::NumberType::get()); } + return scalar_slow(args[i]); +} + +at::Scalar PythonArgs::scalar_slow(PyObject* arg) { // Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently // handle most NumPy scalar types except np.float64. - if (THPVariable_Check(args[i])) { - return ((THPVariable*)args[i])->cdata.item(); + if (THPVariable_Check(arg)) { + return ((THPVariable*)arg)->cdata.item(); } - if (THPUtils_checkLong(args[i])) { - return at::Scalar(static_cast(THPUtils_unpackLong(args[i]))); + if (THPUtils_checkLong(arg)) { + return at::Scalar(static_cast(THPUtils_unpackLong(arg))); } - if (PyBool_Check(args[i])) { - return at::Scalar(THPUtils_unpackBool(args[i])); + if (PyBool_Check(arg)) { + return at::Scalar(THPUtils_unpackBool(arg)); } - if (PyComplex_Check(args[i])) { - return at::Scalar(THPUtils_unpackComplexDouble(args[i])); + if (PyComplex_Check(arg)) { + return at::Scalar(THPUtils_unpackComplexDouble(arg)); } - return at::Scalar(THPUtils_unpackDouble(args[i])); + return at::Scalar(THPUtils_unpackDouble(arg)); } } // namespace torch diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index b0b81a9517da..ccf3ba6b42c4 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -80,7 +80,7 @@ namespace torch { enum class ParameterType { TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR, BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING, - DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST + DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST }; struct FunctionParameter; @@ -158,6 +158,7 @@ struct PythonArgs { inline c10::optional optionalTensor(int i); inline at::Scalar scalar(int i); inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar); + inline std::vector scalarlist(int i); inline std::vector tensorlist(int i); template inline std::array tensorlist_n(int i); @@ -206,6 +207,7 @@ struct PythonArgs { private: at::Tensor tensor_slow(int i); at::Scalar scalar_slow(int i); + at::Scalar scalar_slow(PyObject* arg); }; struct FunctionParameter { @@ -287,6 +289,19 @@ inline at::Scalar PythonArgs::scalar(int i) { return scalar_slow(i); } +inline std::vector PythonArgs::scalarlist(int i) { + if (!args[i]) return std::vector(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + std::vector res(size); + for (int idx = 0; idx < size; idx++) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); + res[idx] = scalar_slow(obj); + } + return res; +} + inline at::Scalar PythonArgs::scalarWithDefault(int i, at::Scalar default_scalar) { if (!args[i]) return default_scalar; return scalar_slow(i); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 0850b535fe30..a4068ac6d7f3 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -16,7 +16,7 @@ import threading from typing import List, Optional, Tuple, Union from ._utils import _get_device_index, _dummy_type -from .streams import Stream, Event +from .streams import Stream, Event, _Graph from .. import device as _device import torch._C diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 14345baf6abd..9c9c30a7ff29 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -8,6 +8,7 @@ # Define dummy base classes torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase') torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase') + torch._C.__dict__['_CudaGraphBase'] = _dummy_type('_CudaGraphBase') class Stream(torch._C._CudaStreamBase): r"""Wrapper around a CUDA stream. @@ -20,7 +21,7 @@ class Stream(torch._C._CudaStreamBase): device(torch.device or int, optional): a device on which to allocate the stream. If :attr:`device` is ``None`` (default) or a negative integer, this will use the current device. - priority(int, optional): priority of the stream. Can be either + priority(int, optional): priority of the stream. Can be either -1 (high priority) or 0 (low priority). By default, streams have priority 0. @@ -201,3 +202,5 @@ def __repr__(self): return ''.format(self._as_parameter_.value) else: return '' + +_Graph = torch._C._CudaGraphBase diff --git a/torch/distributed/_pipeline/sync/balance/__init__.py b/torch/distributed/_pipeline/sync/balance/__init__.py index 15aa53bc1a2c..8c6da586657f 100644 --- a/torch/distributed/_pipeline/sync/balance/__init__.py +++ b/torch/distributed/_pipeline/sync/balance/__init__.py @@ -18,7 +18,7 @@ pipe = Pipe(model, balance, chunks=8) """ -from typing import List, Tuple, Union +from typing import List, Union, Sequence import torch from torch import Tensor @@ -32,7 +32,7 @@ Device = Union[torch.device, int, str] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] diff --git a/torch/distributed/_pipeline/sync/balance/profile.py b/torch/distributed/_pipeline/sync/balance/profile.py index 737dda60f6fa..382da988e808 100644 --- a/torch/distributed/_pipeline/sync/balance/profile.py +++ b/torch/distributed/_pipeline/sync/balance/profile.py @@ -7,7 +7,7 @@ """Per-layer profilers.""" import copy import time -from typing import Generator, List, Tuple, Union +from typing import Generator, List, Union, Sequence import torch from torch import Tensor @@ -20,7 +20,7 @@ Device = Union[torch.device, int, str] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] diff --git a/torch/distributed/_pipeline/sync/checkpoint.py b/torch/distributed/_pipeline/sync/checkpoint.py index bad5eec19469..3f9240793183 100644 --- a/torch/distributed/_pipeline/sync/checkpoint.py +++ b/torch/distributed/_pipeline/sync/checkpoint.py @@ -27,7 +27,16 @@ from collections import deque from contextlib import contextmanager import threading -from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Deque, + Generator, + List, + Optional, + Union, + Sequence, + Tuple +) import torch from torch import Tensor @@ -40,7 +49,7 @@ __all__ = ["is_checkpointing", "is_recomputing"] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] # Types for shared memory between Checkpoint and Recompute. diff --git a/torch/distributed/_pipeline/sync/copy.py b/torch/distributed/_pipeline/sync/copy.py index 3d330f59eeee..07e71a87ce08 100644 --- a/torch/distributed/_pipeline/sync/copy.py +++ b/torch/distributed/_pipeline/sync/copy.py @@ -8,7 +8,7 @@ and computation on the same GPU. """ from collections import deque -from typing import Deque, List, Optional, Tuple +from typing import Deque, List, Optional, Tuple, Sequence import torch from torch import Tensor @@ -18,7 +18,7 @@ __all__: List[str] = [] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] # Common interface between :class:`Copy` and :class:`Wait`. diff --git a/torch/distributed/_pipeline/sync/microbatch.py b/torch/distributed/_pipeline/sync/microbatch.py index d38cb6d3b85c..fc4daf7a9b42 100644 --- a/torch/distributed/_pipeline/sync/microbatch.py +++ b/torch/distributed/_pipeline/sync/microbatch.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. """Manipulation of micro-batches.""" import typing -from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast +from typing import Callable, Iterable, Iterator, List, Union, cast, Sequence import torch from torch import Tensor @@ -15,7 +15,7 @@ __all__: List[str] = [] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] Function = Callable[[TensorOrTensors], TensorOrTensors] @@ -110,7 +110,7 @@ def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None: def _setitem_by_index(self, index: int, value: Tensor) -> None: if not self.atomic: i = index - self.value = self.value[:i] + (value,) + self.value[i + 1 :] + self.value = self.value[:i] + (value,) + self.value[i + 1 :] # type: ignore return if index != 0: @@ -139,9 +139,10 @@ def check(input: TensorOrTensors) -> None: TypeError: input is not a tensor or tensors. """ - if isinstance(input, tuple): + if isinstance(input, Sequence): for x in input: - check(x) + if not isinstance(x, Tensor): + raise TypeError(f"expected Tensor, but got {input.__class__.__name__}") return if not isinstance(input, Tensor): diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py index 92a3c301cc39..82db93060d91 100644 --- a/torch/distributed/_pipeline/sync/pipe.py +++ b/torch/distributed/_pipeline/sync/pipe.py @@ -6,10 +6,11 @@ # LICENSE file in the root directory of this source tree. """The Pipe interface.""" from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast, Sequence import torch from torch import Tensor, nn +from torch.distributed.rpc import RRef import torch.autograd import torch.cuda @@ -26,7 +27,7 @@ Device = Union[torch.device, int, str] Devices = Union[Iterable[Device], List[Device]] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] if TYPE_CHECKING: @@ -305,18 +306,18 @@ def _ensure_copy_streams(self) -> List[List[AbstractStream]]: return self._copy_streams - def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + def forward(self, input: TensorOrTensors) -> RRef[TensorOrTensors]: # type: ignore """:class:`Pipe` is a fairly transparent module wrapper. It doesn't modify the input and output signature of the underlying module. But there's type restriction. Input and output have to be a - :class:`~torch.Tensor` or a tuple of tensors. This restriction is + :class:`~torch.Tensor` or a sequence of tensors. This restriction is applied at partition boundaries too. Args: - input (torch.Tensor or tensors): input mini-batch + input (torch.Tensor or Sequence[torch.Tensor]): input mini-batch Returns: - tensor or tensors: output mini-batch + :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch Raises: TypeError: input is not a tensor or tensors. @@ -326,7 +327,7 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore if not self.devices: # Empty sequential module is not illegal. - return input + return RRef(input) # Divide a mini-batch into micro-batches. batches = microbatch.scatter(input, self.chunks) @@ -336,4 +337,4 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore # Merge the micro-batches into one mini-batch. output = microbatch.gather(batches) - return output + return RRef(output) diff --git a/torch/distributed/_pipeline/sync/pipeline.py b/torch/distributed/_pipeline/sync/pipeline.py index 86c8dfddebeb..72c04c6f28d0 100644 --- a/torch/distributed/_pipeline/sync/pipeline.py +++ b/torch/distributed/_pipeline/sync/pipeline.py @@ -7,7 +7,7 @@ """The pipeline parallelism of Pipe.""" from queue import Queue from types import TracebackType -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence import torch from torch import Tensor, nn @@ -25,7 +25,7 @@ __all__: List[str] = [] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] diff --git a/torch/distributed/_pipeline/sync/skip/skippable.py b/torch/distributed/_pipeline/sync/skip/skippable.py index 9bb258382b9b..e0b0dae584a2 100644 --- a/torch/distributed/_pipeline/sync/skip/skippable.py +++ b/torch/distributed/_pipeline/sync/skip/skippable.py @@ -17,6 +17,7 @@ List, Optional, Set, + Sequence, Tuple, Type, TypeVar, @@ -33,7 +34,7 @@ __all__ = ["skippable", "stash", "pop", "verify_skippables"] -Tensors = Tuple[Tensor, ...] +Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] StashPop = Union["stash", "pop"] diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index e1d475a34425..751621189706 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -47,6 +47,8 @@ def __init__( random_seed=0, ): self.process_group = process_group + # The low rank for matrix approximation. + # Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf. self.matrix_approximation_rank = matrix_approximation_rank # Error feedback is usually crucial for both for convergence and generalization, # because PowerSGD is a biased compressor, @@ -97,8 +99,6 @@ def powerSGD_hook( bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode at this time, only exactly one tensor is stored in this bucket. - matrix_approximation_rank (int): The low rank for matrix approximation. - Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf. Returns: Future handler of the communication, which updates the gradients in place. @@ -126,6 +126,7 @@ def powerSGD_hook( # Incorporate the error from the previous state into the gradients. bucket_index = bucket.get_index() + input_tensor_cp = None if state.use_error_feedback: # The buckets can be rebuilt during training. # In this case, the error tensor shape will not be aligned with the input tensor, @@ -162,11 +163,17 @@ def create_low_rank_tensor(fill_random_values, rng): # only fork on CPU and then move the generated tensor to the CUDA device. torch.manual_seed(rng.randint(1_000_000_000)) return torch.randn( - square_side_length, state.matrix_approximation_rank, device="cpu" + square_side_length, + state.matrix_approximation_rank, + device="cpu", + dtype=input_tensor.dtype, ).to(device) else: return torch.empty( - square_side_length, state.matrix_approximation_rank, device=device + square_side_length, + state.matrix_approximation_rank, + device=device, + dtype=input_tensor.dtype, ) p = create_low_rank_tensor(fill_random_values=False, rng=state.rng) @@ -183,9 +190,7 @@ def compute_q(fut): torch.matmul(matrix.t(), p, out=q) return [ - dist.all_reduce(q, group=group_to_use, async_op=True) - .get_future() - .value()[0] + dist.all_reduce(q, group=group_to_use, async_op=True).get_future().wait()[0] ] def decompress(fut): @@ -195,6 +200,7 @@ def decompress(fut): if state.use_error_feedback: # Memorize the local errors. state.error_dict[bucket_index] = input_tensor_cp - input_tensor + assert not torch.any(torch.isnan(state.error_dict[bucket_index])) ret = input_tensor.resize_(total_length) return [ret] diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 13a950024af9..387da70403b0 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -147,7 +147,8 @@ def __getattribute__(self, key): class group(object): - WORLD = object() + # Points to the default PG once initialized. + WORLD: Optional[ProcessGroup] = None class GroupMember(object): @@ -166,7 +167,6 @@ class GroupMember(object): _pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {} # Default process group state -_default_pg: Optional[ProcessGroup] = None _default_pg_init_method = None # Process group count for default naming @@ -177,7 +177,7 @@ def _rank_not_in_group(group: ProcessGroup): """ Helper that checks if the current process's rank is not in a given group. """ - if group == GroupMember.WORLD: + if group is None: return False return group == GroupMember.NON_GROUP_MEMBER @@ -214,23 +214,12 @@ def _get_global_rank(group, group_rank): raise RuntimeError("The group rank is not part of the group") -def _check_default_pg() -> ProcessGroup: - """ - Helper that checks if the default ProcessGroup has been initialized, with - assertion. - """ - if _default_pg is not None: - return _default_pg - else: - raise RuntimeError("Default process group is not initialized") - - def _get_group_size(group): """ Helper that gets a given group's world size. """ - if group is GroupMember.WORLD: - default_pg = _check_default_pg() + if group is GroupMember.WORLD or group is None: + default_pg = _get_default_group() return default_pg.size() if group not in _pg_group_ranks: raise RuntimeError("The given group does not exist") @@ -306,7 +295,7 @@ def is_initialized(): """ Checking if the default process group has been initialized """ - return _default_pg is not None + return GroupMember.WORLD is not None def _get_default_group(): @@ -316,7 +305,7 @@ def _get_default_group(): if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " "please make sure to call init_process_group.") - return _default_pg + return GroupMember.WORLD def _get_default_store(): @@ -326,12 +315,15 @@ def _get_default_store(): if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " "please make sure to call init_process_group.") - default_pg = _check_default_pg() + default_pg = _get_default_group() _, default_store = _pg_map[default_pg] return default_store +def _update_default_pg(pg): + GroupMember.WORLD = group.WORLD = pg -def get_backend(group=group.WORLD): + +def get_backend(group=None): """ Returns the backend of the given process group. @@ -344,8 +336,8 @@ def get_backend(group=group.WORLD): The backend of the given process group as a lower case string. """ - if group == GroupMember.WORLD: - pg = _check_default_pg() + if group is None: + pg = _get_default_group() else: pg = group if _rank_not_in_group(pg): @@ -421,14 +413,13 @@ def init_process_group(backend, """ global _pg_group_ranks global _backend - global _default_pg global _default_pg_init_method if not isinstance(timeout, timedelta): raise RuntimeError("Expected timeout argument to be of type" "datetime.timedelta") - if _default_pg is not None: + if GroupMember.WORLD is not None: raise RuntimeError("trying to initialize the default process group " "twice!") @@ -450,14 +441,14 @@ def init_process_group(backend, "are ignored since they are assigned by the " "MPI runtime.".format(world_size, rank)) - _default_pg = _new_process_group_helper( + _update_default_pg(_new_process_group_helper( -1, -1, [], Backend.MPI, None, group_name=group_name, - timeout=timeout) + timeout=timeout)) else: # backward compatible API if store is None: @@ -467,17 +458,17 @@ def init_process_group(backend, store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) - _default_pg = _new_process_group_helper( + _update_default_pg(_new_process_group_helper( world_size, rank, [], backend, store, group_name=group_name, - timeout=timeout) + timeout=timeout)) - _pg_group_ranks[_default_pg] = {i: i for i in range(_default_pg.size())} - _backend = _pg_map[_default_pg][0] + _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore + _backend = _pg_map[GroupMember.WORLD][0] # type: ignore _default_pg_init_method = init_method # barrier at the end to ensure that once we return from this method, all @@ -537,7 +528,7 @@ def _new_process_group_helper(world_size, # If this is a subgroup (which means group_ranks is specified), # we check if the current process is a member of the new group. if not is_default_group: - global_rank = _check_default_pg().rank() + global_rank = _get_default_group().rank() if global_rank not in group_ranks: return GroupMember.NON_GROUP_MEMBER @@ -576,7 +567,7 @@ def _new_process_group_helper(world_size, return pg -def destroy_process_group(group=group.WORLD): +def destroy_process_group(group=None): """ Destroy a given process group, and deinitialize the distributed package @@ -589,15 +580,14 @@ def destroy_process_group(group=group.WORLD): global _pg_map global _pg_names global _pg_group_ranks - global _default_pg global _default_pg_init_method global _group_count if group == GroupMember.NON_GROUP_MEMBER: return - if group == GroupMember.WORLD: - pg = _default_pg + if group is None: + pg = GroupMember.WORLD else: pg = group @@ -605,8 +595,8 @@ def destroy_process_group(group=group.WORLD): if _pg_map.get(pg, None) is None: raise RuntimeError("Invalid process group specified") - if group == GroupMember.WORLD: - _default_pg = None + if group is None or group == GroupMember.WORLD: + _update_default_pg(None) _default_pg_init_method = None _pg_map.clear() _pg_names.clear() @@ -627,7 +617,7 @@ def destroy_process_group(group=group.WORLD): del _pg_group_ranks[pg] -def get_rank(group=group.WORLD): +def get_rank(group=None): """ Returns the rank of current process group @@ -636,7 +626,8 @@ def get_rank(group=group.WORLD): ``world_size``. Arguments: - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: The rank of the process group @@ -646,19 +637,20 @@ def get_rank(group=group.WORLD): if _rank_not_in_group(group): return -1 - default_pg = _check_default_pg() - if group == GroupMember.WORLD: + default_pg = _get_default_group() + if group is None or group is GroupMember.WORLD: return default_pg.rank() return _get_group_rank(group, default_pg.rank()) -def get_world_size(group=group.WORLD): +def get_world_size(group=None): """ Returns the number of processes in the current process group Arguments: - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: The world size of the process group @@ -673,7 +665,7 @@ def get_world_size(group=group.WORLD): def isend(tensor, dst, - group=group.WORLD, + group=None, tag=0): """ Sends a tensor asynchronously. @@ -681,7 +673,8 @@ def isend(tensor, Arguments: tensor (Tensor): Tensor to send. dst (int): Destination rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with remote recv Returns: @@ -693,8 +686,8 @@ def isend(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() return default_pg.send([tensor], dst, tag) else: group_dst_rank = _get_group_rank(group, dst) @@ -703,7 +696,7 @@ def isend(tensor, def irecv(tensor, src, - group=group.WORLD, + group=None, tag=0): """ Receives a tensor asynchronously. @@ -711,7 +704,8 @@ def irecv(tensor, Arguments: tensor (Tensor): Tensor to fill with received data. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match recv with remote send Returns: @@ -723,8 +717,8 @@ def irecv(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() return default_pg.recv([tensor], src, tag) else: group_src_rank = _get_group_rank(group, src) @@ -733,7 +727,7 @@ def irecv(tensor, def send(tensor, dst, - group=group.WORLD, + group=None, tag=0): """ Sends a tensor synchronously. @@ -741,7 +735,8 @@ def send(tensor, Arguments: tensor (Tensor): Tensor to send. dst (int): Destination rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with remote recv """ @@ -749,8 +744,8 @@ def send(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() default_pg.send([tensor], dst, tag).wait() else: group_dst_rank = _get_group_rank(group, dst) @@ -759,7 +754,7 @@ def send(tensor, def recv(tensor, src=None, - group=group.WORLD, + group=None, tag=0): """ Receives a tensor synchronously. @@ -768,7 +763,8 @@ def recv(tensor, tensor (Tensor): Tensor to fill with received data. src (int, optional): Source rank. Will receive from any process if unspecified. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match recv with remote send Returns: @@ -780,8 +776,8 @@ def recv(tensor, if _rank_not_in_group(group): return -1 - if group == GroupMember.WORLD: - pg = _check_default_pg() + if group is None: + pg = _get_default_group() else: pg = group @@ -789,12 +785,12 @@ def recv(tensor, work = pg.recv_anysource([tensor], tag) work.wait() src_rank = work._source_rank() - if group == GroupMember.WORLD: + if group is None or group is GroupMember.WORLD: return src_rank else: return _get_global_rank(pg, src_rank) else: - if group == GroupMember.WORLD: + if group is None or group is GroupMember.WORLD: pg.recv([tensor], src, tag).wait() else: group_src_rank = _get_group_rank(pg, src) @@ -816,17 +812,18 @@ class P2POp(object): ``torch.distributed.irecv``. tensor (Tensor): Tensor to send or receive. peer (int): Destination or source rank. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with recv. """ - def __init__(self, op, tensor, peer, group=group.WORLD, tag=0): + def __init__(self, op, tensor, peer, group=None, tag=0): self.op = op self.tensor = tensor self.peer = peer self.group = group self.tag = tag - def __new__(cls, op, tensor, peer, group=group.WORLD, tag=0): + def __new__(cls, op, tensor, peer, group=None, tag=0): _check_op(op) _check_single_tensor(tensor, "tensor") return object.__new__(cls) @@ -896,7 +893,7 @@ def batch_isend_irecv(p2p_op_list): def broadcast_multigpu(tensor_list, src, - group=group.WORLD, + group=None, async_op=False, src_tensor=0): """ @@ -920,7 +917,8 @@ def broadcast_multigpu(tensor_list, for all the distributed processes calling this function. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op src_tensor (int, optional): Source tensor rank within ``tensor_list`` @@ -936,8 +934,8 @@ def broadcast_multigpu(tensor_list, opts.rootRank = src opts.rootTensor = src_tensor - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.broadcast(tensor_list, opts) else: group_src_rank = _get_group_rank(group, src) @@ -951,7 +949,7 @@ def broadcast_multigpu(tensor_list, def broadcast(tensor, src, - group=group.WORLD, + group=None, async_op=False): """ Broadcasts the tensor to the whole group. @@ -963,7 +961,8 @@ def broadcast(tensor, tensor (Tensor): Data to be sent if ``src`` is the rank of current process, and tensor to be used to save received data otherwise. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -979,8 +978,8 @@ def broadcast(tensor, opts.rootRank = src opts.rootTensor = 0 - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.broadcast([tensor], opts) else: group_src_rank = _get_group_rank(group, src) @@ -994,7 +993,7 @@ def broadcast(tensor, def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): r""" Reduces the tensor data across all machines in such a way that all get @@ -1020,7 +1019,8 @@ def all_reduce_multigpu(tensor_list, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1035,8 +1035,8 @@ def all_reduce_multigpu(tensor_list, opts = AllreduceOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allreduce(tensor_list, opts) else: work = group.allreduce(tensor_list, opts) @@ -1049,7 +1049,7 @@ def all_reduce_multigpu(tensor_list, def all_reduce(tensor, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces the tensor data across all machines in such a way that all get @@ -1065,7 +1065,8 @@ def all_reduce(tensor, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1107,8 +1108,8 @@ def all_reduce(tensor, opts = AllreduceOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allreduce([tensor], opts) else: work = group.allreduce([tensor], opts) @@ -1121,7 +1122,7 @@ def all_reduce(tensor, def all_reduce_coalesced(tensors, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ WARNING: at this time individual shape checking is not implemented across nodes. @@ -1146,7 +1147,8 @@ def all_reduce_coalesced(tensors, op (Optional[ReduceOp]): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (Optional[ProcessGroup]): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (Optional[bool]): Whether this op should be an async op. Returns: @@ -1165,8 +1167,8 @@ def all_reduce_coalesced(tensors, opts = AllreduceCoalescedOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allreduce_coalesced(tensors, opts) else: work = group.allreduce_coalesced(tensors, opts) @@ -1180,7 +1182,7 @@ def all_reduce_coalesced(tensors, def reduce_multigpu(tensor_list, dst, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False, dst_tensor=0): """ @@ -1202,7 +1204,8 @@ def reduce_multigpu(tensor_list, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op dst_tensor (int, optional): Destination tensor rank within ``tensor_list`` @@ -1220,8 +1223,8 @@ def reduce_multigpu(tensor_list, opts.rootRank = dst opts.rootTensor = dst_tensor - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.reduce(tensor_list, opts) else: group_dst_rank = _get_group_rank(group, dst) @@ -1237,7 +1240,7 @@ def reduce_multigpu(tensor_list, def reduce(tensor, dst, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces the tensor data across all machines. @@ -1251,7 +1254,8 @@ def reduce(tensor, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1267,8 +1271,8 @@ def reduce(tensor, opts.reduceOp = op opts.rootRank = dst - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.reduce([tensor], opts) else: group_dst_rank = _get_group_rank(group, dst) @@ -1283,7 +1287,7 @@ def reduce(tensor, def all_gather_multigpu(output_tensor_lists, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Gathers tensors from the whole group in a list. @@ -1318,7 +1322,8 @@ def all_gather_multigpu(output_tensor_lists, Note that ``len(input_tensor_list)`` needs to be the same for all the distributed processes calling this function. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1332,8 +1337,8 @@ def all_gather_multigpu(output_tensor_lists, output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allgather(output_tensor_lists, input_tensor_list) else: work = group.allgather(output_tensor_lists, input_tensor_list) @@ -1358,7 +1363,7 @@ def _tensor_to_object(tensor, tensor_size): return out -def all_gather_object(object_list, obj, group=group.WORLD): +def all_gather_object(object_list, obj, group=None): """ Gathers picklable objects from the whole group into a list. Similar to :func:`all_gather`, but Python objects can be passed in. Note that the object @@ -1393,6 +1398,16 @@ def all_gather_object(object_list, obj, group=group.WORLD): known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] """ if _rank_not_in_group(group): return @@ -1417,7 +1432,7 @@ def all_gather_object(object_list, obj, group=group.WORLD): ] # Allgather tensor sizes all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) + max_object_size = int(max(object_size_list).item()) # type: ignore # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) coalesced_output_tensor = torch.empty( @@ -1436,7 +1451,7 @@ def all_gather_object(object_list, obj, group=group.WORLD): object_list[i] = _tensor_to_object(tensor, tensor_size) -def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): +def gather_object(obj, object_gather_list=None, dst=0, group=None): """ Gathers picklable objects from the whole group in a single process. Similar to :func:`gather`, but Python objects can be passed in. Note that the @@ -1467,6 +1482,21 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + gather_objects[dist.get_rank()], + output if dist.get_rank() == 0 else None, + dst=0 + ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] """ if _rank_not_in_group(group): return @@ -1493,7 +1523,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): # gather, since each rank needs to broadcast a tensor of the same (maximal) # size. all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) + max_object_size = int(max(object_size_list).item()) # type: ignore # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) # Avoid populating output tensors if the result won't be gathered on this rank. @@ -1521,7 +1551,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): object_gather_list[i] = _tensor_to_object(tensor, tensor_size) -def broadcast_object_list(object_list, src, group=group.WORLD): +def broadcast_object_list(object_list, src, group=None): """ Broadcasts picklable objects in ``object_list`` to the whole group. Similar to :func:`broadcast`, but Python objects can be passed in. @@ -1556,6 +1586,18 @@ def broadcast_object_list(object_list, src, group=group.WORLD): is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> objects = [None, None, None] + >>> dist.broadcast_object_list(objects, src=0) + >>> broadcast_objects + ['foo', 12, {1: 2}] """ if _rank_not_in_group(group): return @@ -1634,6 +1676,21 @@ def scatter_object_list( is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. For example, on rank 2: + >>> output_list + [{1: 2}] """ if _rank_not_in_group(group): return @@ -1687,7 +1744,7 @@ def scatter_object_list( def all_gather(tensor_list, tensor, - group=group.WORLD, + group=None, async_op=False): """ Gathers tensors from the whole group in a list. @@ -1698,7 +1755,8 @@ def all_gather(tensor_list, tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. tensor (Tensor): Tensor to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1743,8 +1801,8 @@ def all_gather(tensor_list, tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allgather([tensor_list], [tensor]) else: work = group.allgather([tensor_list], [tensor]) @@ -1756,7 +1814,7 @@ def all_gather(tensor_list, def all_gather_coalesced(output_tensor_lists, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Gathers input tensors from the whole group in a list in a coalesced manner. @@ -1768,7 +1826,8 @@ def all_gather_coalesced(output_tensor_lists, correctly-sized tensors to be used for output of the collective. input_tensor_list (list[Tensor]): Tensors to be broadcast from current process. At least one tensor has to be non empty. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -1814,8 +1873,8 @@ def all_gather_coalesced(output_tensor_lists, output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allgather_coalesced( output_tensor_lists, input_tensor_list) else: @@ -1842,7 +1901,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): def gather(tensor, gather_list=None, dst=0, - group=group.WORLD, + group=None, async_op=False): """ Gathers a list of tensors in a single process. @@ -1853,7 +1912,8 @@ def gather(tensor, tensors to use for gathered data (default is None, must be specified on the destination rank) dst (int, optional): Destination rank (default is 0) - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1880,8 +1940,8 @@ def gather(tensor, opts = GatherOptions() opts.rootRank = dst - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.gather(output_tensors, input_tensors, opts) else: group_dst_rank = _get_group_rank(group, dst) @@ -1897,7 +1957,7 @@ def gather(tensor, def scatter(tensor, scatter_list=None, src=0, - group=group.WORLD, + group=None, async_op=False): """ Scatters a list of tensors to all processes in a group. @@ -1910,7 +1970,8 @@ def scatter(tensor, scatter_list (list[Tensor]): List of tensors to scatter (default is None, must be specified on the source rank) src (int): Source rank (default is 0) - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1946,8 +2007,8 @@ def scatter(tensor, opts = ScatterOptions() opts.rootRank = src - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.scatter(output_tensors, input_tensors, opts) else: group_src_rank = _get_group_rank(group, src) @@ -1963,7 +2024,7 @@ def scatter(tensor, def reduce_scatter_multigpu(output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduce and scatter a list of tensors to the whole group. Only nccl backend @@ -1997,7 +2058,8 @@ def reduce_scatter_multigpu(output_tensor_list, therefore ``len(input_tensor_lists[i])``) need to be the same for all the distributed processes calling this function. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2011,8 +2073,8 @@ def reduce_scatter_multigpu(output_tensor_list, opts = ReduceScatterOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.reduce_scatter( output_tensor_list, input_tensor_lists, @@ -2034,7 +2096,7 @@ def reduce_scatter_multigpu(output_tensor_list, def reduce_scatter(output, input_list, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces, then scatters a list of tensors to all processes in a group. @@ -2042,7 +2104,8 @@ def reduce_scatter(output, Arguments: output (Tensor): Output tensor. input_list (list[Tensor]): List of tensors to reduce and scatter. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2058,8 +2121,8 @@ def reduce_scatter(output, opts = ReduceScatterOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.reduce_scatter([output], [input_list], opts) else: work = group.reduce_scatter([output], [input_list], opts) @@ -2074,7 +2137,7 @@ def all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, - group=group.WORLD, + group=None, async_op=False): """ Each process splits input tensor and then scatters the split list @@ -2090,7 +2153,8 @@ def all_to_all_single(output, input_split_sizes: (list[Int], optional): Input split sizes for dim 0 if specified None or empty, dim 0 of ``input`` tensor must divide equally by ``world_size``. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2154,8 +2218,8 @@ def all_to_all_single(output, output_split_sizes = [] if output_split_sizes is None else output_split_sizes input_split_sizes = [] if input_split_sizes is None else input_split_sizes - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts) else: work = group.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts) @@ -2167,7 +2231,7 @@ def all_to_all_single(output, def all_to_all(output_tensor_list, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Each process scatters list of input tensors to all processes in a group and @@ -2177,7 +2241,8 @@ def all_to_all(output_tensor_list, output_tensor_list (list[Tensor]): List of tensors to be gathered one per rank. input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2245,8 +2310,8 @@ def all_to_all(output_tensor_list, _check_tensor_list(output_tensor_list, "output_tensor_list") _check_tensor_list(input_tensor_list, "input_tensor_list") - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.alltoall(output_tensor_list, input_tensor_list, opts) else: work = group.alltoall(output_tensor_list, input_tensor_list, opts) @@ -2257,7 +2322,7 @@ def all_to_all(output_tensor_list, work.wait() -def barrier(group=group.WORLD, +def barrier(group=GroupMember.WORLD, async_op=False): """ Synchronizes all processes. @@ -2266,7 +2331,8 @@ def barrier(group=group.WORLD, if async_op is False, or if async work handle is called on wait(). Arguments: - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -2276,8 +2342,8 @@ def barrier(group=group.WORLD, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.barrier() else: work = group.barrier() @@ -2297,6 +2363,17 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None): if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes. + .. warning:: + Using multiple process groups with the ``NCCL`` backend concurrently + is not safe and the user should perform explicit synchronization in + their application to ensure only one process group is used at a time. + This means collectives from one process group should have completed + execution on the device (not just enqueued since CUDA execution is + async) before collectives from another process group are enqueued. + See `Using multiple NCCL communicators concurrently `_ for more details. + Arguments: ranks (list[int]): List of ranks of group members. If ``None``, will be set to all ranks. Default is ``None``. @@ -2317,7 +2394,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None): global _pg_group_ranks - default_pg = _check_default_pg() + default_pg = _get_default_group() default_backend, default_store = _pg_map[default_pg] global_rank = default_pg.rank() global_world_size = default_pg.size() diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 3ee7e9b2a4b0..678bbf6a96de 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -18,9 +18,9 @@ from torch import Tensor, device, dtype, nn from torch.distributed.nn.jit import instantiator from torch.distributed.rpc.utils import _parse_remote_device +from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle -from torch.nn import Module _grad_t = Union[Tuple[Tensor, ...], Tensor] @@ -209,6 +209,10 @@ def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]: """ return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse)) + def get_module_rref(self) -> rpc.RRef[nn.Module]: + """Returns the RRef to remote module.""" + return self.module_rref + def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py index 51e48d58de3a..dd963ab6f7a4 100644 --- a/torch/distributions/__init__.py +++ b/torch/distributions/__init__.py @@ -93,6 +93,7 @@ from .kl import kl_divergence, register_kl from .kumaraswamy import Kumaraswamy from .laplace import Laplace +from .lkj_cholesky import LKJCholesky from .log_normal import LogNormal from .logistic_normal import LogisticNormal from .lowrank_multivariate_normal import LowRankMultivariateNormal diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py new file mode 100644 index 000000000000..cdbfe5be55bb --- /dev/null +++ b/torch/distributions/lkj_cholesky.py @@ -0,0 +1,126 @@ +""" +This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro). + +Original copyright notice: + +# Copyright: Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 +""" + +import math + +import torch +from torch.distributions import constraints, Beta +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all + + +class LKJCholesky(Distribution): + r""" + LKJ distribution for lower Cholesky factor of correlation matrices. + The distribution is controlled by ``concentration`` parameter :math:`\eta` + to make the probability of the correlation matrix :math:`M` generated from + a Cholesky factor propotional to :math:`\det(M)^{\eta - 1}`. Because of that, + when ``concentration == 1``, we have a uniform distribution over Cholesky + factors of correlation matrices. Note that this distribution samples the + Cholesky factor of correlation matrices and not the correlation matrices + themselves and thereby differs slightly from the derivations in [1] for + the `LKJCorr` distribution. For sampling, this uses the Onion method from + [1] Section 3. + + L ~ LKJCholesky(dim, concentration) + X = L @ L' ~ LKJCorr(dim, concentration) + + Example:: + + >>> l = LKJCholesky(3, 0.5) + >>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix + tensor([[ 1.0000, 0.0000, 0.0000], + [ 0.3516, 0.9361, 0.0000], + [-0.1899, 0.4748, 0.8593]]) + + Args: + dimension (dim): dimension of the matrices + concentration (float or Tensor): concentration/shape parameter of the + distribution (often referred to as eta) + + **References** + + [1] `Generating random correlation matrices based on vines and extended onion method`, + Daniel Lewandowski, Dorota Kurowicka, Harry Joe. + """ + arg_constraints = {'concentration': constraints.positive} + support = constraints.corr_cholesky + + def __init__(self, dim, concentration=1., validate_args=None): + if dim < 2: + raise ValueError(f'Expected dim to be an integer greater than or equal to 2. Found dim={dim}.') + self.dim = dim + self.concentration, = broadcast_all(concentration) + batch_shape = self.concentration.size() + event_shape = torch.Size((dim, dim)) + # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1]. + marginal_conc = self.concentration + 0.5 * (self.dim - 2) + offset = torch.arange(self.dim - 1, dtype=self.concentration.dtype, device=self.concentration.device) + offset = torch.cat([offset.new_zeros((1,)), offset]) + beta_conc1 = offset + 0.5 + beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset + self._beta = Beta(beta_conc1, beta_conc0) + super(LKJCholesky, self).__init__(batch_shape, event_shape, validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LKJCholesky, _instance) + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + new.concentration = self.concentration.expand(batch_shape) + new._beta = self._beta.expand(batch_shape + (self.dim,)) + super(LKJCholesky, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + # This uses the Onion method, but there are a few differences from [1] Sec. 3.2: + # - This vectorizes the for loop and also works for heterogeneous eta. + # - Same algorithm generalizes to n=1. + # - The procedure is simplified since we are sampling the cholesky factor of + # the correlation matrix instead of the correlation matrix itself. As such, + # we only need to generate `w`. + y = self._beta.sample(sample_shape).unsqueeze(-1) + u_normal = torch.randn(self._extended_shape(sample_shape), + dtype=y.dtype, + device=y.device).tril(-1) + u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True) + # Replace NaNs in first row + u_hypersphere[..., 0, :].fill_(0.) + w = torch.sqrt(y) * u_hypersphere + # Fill diagonal elements; clamp for numerical stability + eps = torch.finfo(w.dtype).tiny + diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt() + w += torch.diag_embed(diag_elems) + return w + + def log_prob(self, value): + # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html + # The probability of a correlation matrix is proportional to + # determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1)) + # Additionally, the Jacobian of the transformation from Cholesky factor to + # correlation matrix is: + # prod(L_ii ^ (D - i)) + # So the probability of a Cholesky factor is propotional to + # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i) + # with order_i = 2 * concentration - 2 + D - i + diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:] + order = torch.arange(2, self.dim + 1) + order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order + unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1) + # Compute normalization constant (page 1999 of [1]) + dm1 = self.dim - 1 + alpha = self.concentration + 0.5 * dm1 + denominator = torch.lgamma(alpha) * dm1 + numerator = torch.mvlgamma(alpha - 0.5, dm1) + # pi_constant in [1] is D * (D - 1) / 4 * log(pi) + # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi) + # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2 + pi_constant = 0.5 * dm1 * math.log(math.pi) + normalize_term = pi_constant + numerator - denominator + return unnormalized_log_pdf - normalize_term diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index a0412d52df0d..4181db799b28 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -733,6 +733,7 @@ class CatTransform(Transform): """ def __init__(self, tseq, dim=0, lengths=None, cache_size=0): assert all(isinstance(t, Transform) for t in tseq) + self.event_dim = max(t.event_dim for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] super(CatTransform, self).__init__(cache_size=cache_size) @@ -784,9 +785,20 @@ def log_abs_det_jacobian(self, x, y): for trans, length in zip(self.transforms, self.lengths): xslice = x.narrow(self.dim, start, length) yslice = y.narrow(self.dim, start, length) - logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice)) + logdetjac = trans.log_abs_det_jacobian(xslice, yslice) + if trans.event_dim < self.event_dim: + logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim) + logdetjacs.append(logdetjac) start = start + length # avoid += for jit compat - return torch.cat(logdetjacs, dim=self.dim) + # Decide whether to concatenate or sum. + dim = self.dim + if dim >= 0: + dim = dim - x.dim() + dim = dim + self.event_dim + if dim < 0: + return torch.cat(logdetjacs, dim=dim) + else: + return sum(logdetjacs) @property def bijective(self): diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 380b98785f6c..84f45f1d33cf 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -118,8 +118,8 @@ def tril_matrix_to_vec(mat, diag=0): which comprises of lower triangular elements from the matrix in row order. """ n = mat.shape[-1] - if not torch._C._get_tracing_state() and (diag <= -n or diag >= n): - raise ValueError(f'diag ({diag}) provided is outside [{-n+1}, {n-1}].') + if not torch._C._get_tracing_state() and (diag < -n or diag >= n): + raise ValueError(f'diag ({diag}) provided is outside [{-n}, {n-1}].') arange = torch.arange(n, device=mat.device) tril_mask = arange < arange.view(-1, 1) + (diag + 1) vec = mat[..., tril_mask] diff --git a/torch/functional.py b/torch/functional.py index cbdbdae66823..c2aabc64200c 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -296,107 +296,76 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): def einsum(equation, *operands): r"""einsum(equation, *operands) -> Tensor - Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation - based on the Einstein summation convention. - - Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them - in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of - this format are described below, but the general idea is to label every dimension of the input :attr:`operands` - with some subscript and define which subscripts are part of the output. The output is then computed by summing - the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the - output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`. - Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). - - Equation: - - The :attr:`equation` string specifies the subscripts (lower case letters `['a', 'z']`) for each dimension of - the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a - comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript - must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is - repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand - must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that - appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. - The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based - on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. - - Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation - followed by the subscripts for the output. For instance, the following equation computes the transpose of a - matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and - at most once for the output. - - Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. - Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, - e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth - dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the - 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not - explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), - before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements - batch matrix multiplication `'...ij,...jk'`. - - A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, - arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. - - .. note:: - - ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions - covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output. - - .. note:: - - This function does not optimize the given expression, so a different formula for the same computation may - run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) - can optimize the formula for you. - - Args: - equation (string): The subscripts for the Einstein summation. - operands (Tensor): The operands to compute the Einstein sum of. - - Examples:: - - # trace - >>> torch.einsum('ii', torch.randn(4, 4)) - tensor(-1.2104) - - # diagonal - >>> torch.einsum('ii->i', torch.randn(4, 4)) - tensor([-0.1034, 0.7952, -0.2433, 0.4545]) - - # outer product - >>> x = torch.randn(5) - >>> y = torch.randn(4) - >>> torch.einsum('i,j->ij', x, y) - tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], - [-0.3744, 0.9381, 1.2685, -1.6070], - [ 0.7208, -1.8058, -2.4419, 3.0936], - [ 0.1713, -0.4291, -0.5802, 0.7350], - [ 0.5704, -1.4290, -1.9323, 2.4480]]) - - # batch matrix multiplication - >>> As = torch.randn(3,2,5) - >>> Bs = torch.randn(3,5,4) - >>> torch.einsum('bij,bjk->bik', As, Bs) - tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], - [-1.6706, -0.8097, -0.8025, -2.1183]], - - [[ 4.2239, 0.3107, -0.5756, -0.2354], - [-1.4558, -0.3460, 1.5087, -0.8530]], - - [[ 2.8153, 1.8787, -4.3839, -1.2112], - [ 0.3728, -2.1131, 0.0921, 0.8305]]]) - - # batch permute - >>> A = torch.randn(2, 3, 4, 5) - >>> torch.einsum('...ij->...ji', A).shape - torch.Size([2, 3, 5, 4]) - - # equivalent to torch.nn.functional.bilinear - >>> A = torch.randn(3,5,4) - >>> l = torch.randn(2,5) - >>> r = torch.randn(2,4) - >>> torch.einsum('bn,anm,bm->ba', l, A, r) - tensor([[-0.3430, -5.2405, 0.4494], - [ 0.3311, 5.5201, -3.0356]]) - """ +This function provides a way of computing multilinear expressions (i.e. sums of products) using the +Einstein summation convention. + +Args: + equation (string): The equation is given in terms of lower case letters (indices) to be associated + with each dimension of the operands and result. The left hand side lists the operands + dimensions, separated by commas. There should be one index letter per tensor dimension. + The right hand side follows after `->` and gives the indices for the output. + If the `->` and right hand side are omitted, it implicitly defined as the alphabetically + sorted list of all indices appearing exactly once in the left hand side. + The indices not apprearing in the output are summed over after multiplying the operands + entries. + If an index appears several times for the same operand, a diagonal is taken. + Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred, + the ellipsis dimensions are at the beginning of the output. + operands (Tensor): The operands to compute the Einstein sum of. + +.. note:: + + This function does not optimize the given expression, so a different formula for the same computation may + run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) + can optimize the formula for you. + +Examples:: + + >>> x = torch.randn(5) + >>> y = torch.randn(4) + >>> torch.einsum('i,j->ij', x, y) # outer product + tensor([[-0.0570, -0.0286, -0.0231, 0.0197], + [ 1.2616, 0.6335, 0.5113, -0.4351], + [ 1.4452, 0.7257, 0.5857, -0.4984], + [-0.4647, -0.2333, -0.1883, 0.1603], + [-1.1130, -0.5588, -0.4510, 0.3838]]) + + + >>> A = torch.randn(3,5,4) + >>> l = torch.randn(2,5) + >>> r = torch.randn(2,4) + >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear + tensor([[-0.3430, -5.2405, 0.4494], + [ 0.3311, 5.5201, -3.0356]]) + + + >>> As = torch.randn(3,2,5) + >>> Bs = torch.randn(3,5,4) + >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication + tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], + [-1.6706, -0.8097, -0.8025, -2.1183]], + + [[ 4.2239, 0.3107, -0.5756, -0.2354], + [-1.4558, -0.3460, 1.5087, -0.8530]], + + [[ 2.8153, 1.8787, -4.3839, -1.2112], + [ 0.3728, -2.1131, 0.0921, 0.8305]]]) + + >>> A = torch.randn(3, 3) + >>> torch.einsum('ii->i', A) # diagonal + tensor([-0.7825, 0.8291, -0.1936]) + + >>> A = torch.randn(4, 3, 3) + >>> torch.einsum('...ii->...i', A) # batch diagonal + tensor([[-1.0864, 0.7292, 0.0569], + [-0.9725, -1.0270, 0.6493], + [ 0.5832, -1.1716, -1.5084], + [ 0.4041, -1.1690, 0.8570]]) + + >>> A = torch.randn(2, 3, 4, 5) + >>> torch.einsum('...ij->...ji', A).shape # batch permute + torch.Size([2, 3, 5, 4]) +""" if not torch.jit.is_scripting(): if any(type(t) is not Tensor for t in operands) and has_torch_function(operands): return handle_torch_function(einsum, operands, equation, *operands) @@ -464,10 +433,13 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, r"""Short-time Fourier transform (STFT). .. warning:: - From version 1.8.0, :attr:`return_complex` must be given explicitly for - real inputs. Set to True to return a complex output, or False to - preserve the legacy behavior of returning a real tensor with an extra - last dimension for the real and imaginary components. + From version 1.8.0, :attr:`return_complex` must always be given + explicitly for real inputs and `return_complex=False` has been + deprecated. Strongly prefer `return_complex=True` as in a future + pytorch release, this function will only return complex tensors. + + Note that :func:`torch.view_as_real` can be used to recover a real + tensor with an extra last dimension for real and imaginary components. The STFT computes the Fourier transform of short overlapping windows of the input. This giving frequency components of the signal as they change over @@ -1298,7 +1270,13 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa more consistent with NumPy's numpy.linalg.norm. Args: - input (Tensor): the input tensor + input (Tensor): The input tensor. Its data type must be either a floating + point or complex type. For complex inputs, the norm is calculated using the + absolute value of each element. If the input is complex and neither + :attr:`dtype` nor :attr:`out` is specified, the result's data type will + be the corresponding floating point type (e.g. float if :attr:`input` is + complexfloat). + p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'`` The following norms can be calculated: diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 43ec348d45e6..a995a58c5774 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -10,9 +10,8 @@ PartitionMode class DAGNode(): - """ - DAGNode class maintains useful information for a partition (submodule). - inputs(submodule node) and outputs(submodule node). + """DAGNode class maintains useful information for a partition (submodule), + and its input submodules and output submodules. """ def __init__( self, @@ -48,7 +47,7 @@ def create_node( self.nodes.append(node) class PartitionResult(NamedTuple): - """NameTuple used for returning DAG and a new graph module + """NameTuple used for returning DAG and a new fx module """ dag: DAG module_with_submodules: GraphModule @@ -73,7 +72,6 @@ def combine_two_partitions( partitions.append(partition) partitions.remove(partition_0) partitions.remove(partition_1) - # Reorganize partitions reorganize_partitions(partitions) return @@ -92,7 +90,7 @@ def set_parents_and_children(partitions: List[Partition]) -> None: # For each node in the current partition, find its users users = node.users for n in users: - # Find which the partition the user belongs to. + # Find which the partition the user node belongs to. # Note that if the node itself is also belongs to that partition, # that partition is not the child of the current partition for p in partitions: @@ -103,7 +101,7 @@ def set_parents_and_children(partitions: List[Partition]) -> None: def reorganize_partitions(partitions: List[Partition]) -> None: """Given a list of partitions, reorganzie partiton id, - its parents and its children for each partition + its parents and its children for each partition """ # Rearrange partition ids for i, partition in enumerate(partitions): @@ -123,7 +121,7 @@ def get_bfs_level_partition(partitions: List[Partition]) -> None: current_level.add(partition) next_level: Set[Partition] = set() level = 0 - # Start bfs + # bfs while current_level: partition = current_level.pop() partition.bfs_level = level @@ -149,7 +147,7 @@ def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int def get_device_to_partitions_mapping(partitions: List[Partition], devices: List[Device]): """Given a list of partitions and a list of devices, - map each partition into a device. + map each partition into a device. """ def calculate_extra_mem_bytes_needed_for(partition: Partition, partitions: List[Partition]): all_nodes: Set[Node] = set() @@ -165,10 +163,10 @@ def calculate_extra_mem_bytes_needed_for(partition: Partition, partitions: List[ def find_device_for(partition: Partition): """Given a partition, find a logical device for the partition - The algorithm is that: - #1. sort all the devices based on left mem size - #2. put the partition on the device that has just enought mem - for that partition + The algorithm is to put the partition on the device + that has just enough mem left for that partition. + device_to_left_mem_bytes is a dictionary between device and its left mem size + sorted by its left mem size """ for d in device_to_left_mem_bytes: extra_size_needed = calculate_extra_mem_bytes_needed_for(partition, device_to_partitions[d]) @@ -188,8 +186,8 @@ def find_device_for(partition: Partition): logical_id_to_device[d.logical_id] = d device_to_partitions[d] = [] device_to_left_mem_bytes[d] = d.available_mem_bytes - # Deal with the partitions that have a device - # Find all no device partitions + # Deal with the partitions that already have a device + # and also collect all partitions without a device (no_device_partitions) no_device_partitions = [] for partition in partitions: if partition.logical_device_ids != []: @@ -199,7 +197,7 @@ def find_device_for(partition: Partition): device_to_left_mem_bytes[device] = d.available_mem_bytes - partition.used_mem_bytes else: no_device_partitions.append(partition) - # Find device for each no device partition + # Find devices for all the partitions without a device found_device = True for partition in no_device_partitions: device_to_left_mem_bytes = { @@ -212,6 +210,9 @@ def find_device_for(partition: Partition): return found_device def check_dependency(partition): + """Given a partition,check if there is a circular dependency on + this partition using bfs + """ visited: Set[Partition] = set([partition]) queue: List[Partition] = [partition] while queue: @@ -226,13 +227,13 @@ def check_dependency(partition): return False class Partitioner: - """A graph module may not fit into one device. - Partitioner class helps cut one graph into subgraphs (partitions), - so that each partition could fit into a different device. - The main function of this class is self.partition_graph. - It will partition the graph based on the scheme specified in partition_config - A DAG structure is returned - along with a new graph module with partitions as submodule nodes. + """A fx module may not fit into one device. + Partitioner class helps partition one fx module into submodules (partitions), + so that the submodules can be executed crossing different accelerators. + The main function of this class is self.partition_graph. + It partitions the fx module based on the scheme specified in partition_config + A DAG structure is returned + along with a new fx module with submodule nodes. """ def __init__(self) -> None: self.partitions: List[Partition] = [] @@ -245,37 +246,40 @@ def partition_graph( torch_module: torch.nn.Module, partitioner_config: PartitionerConfig ) -> PartitionResult: - """ - Given the fx module, torch module and partitioner_config, - find the partitions, do the partitions, - and then return a DAG and a new fx module with submodule nodes (partitions) + """Given the fx module, torch module and partitioner_config, + find the partitions, do the partitions, + and then return a DAG and a new fx module with submodule nodes (partitions) """ self.graph_module = fx_module self.torch_module = torch_module self.devices = partitioner_config.devices if len(self.devices) == 0: raise RuntimeError('No devices') - # Check if there are op nodes in the graph + # Check if there are op nodes in the fx module nodes = self.graph_module.graph.nodes if all(node.op in {'placeholder', 'get_attr', 'output'} for node in nodes): raise RuntimeError('No Partition since no operations in the module') - # Calculate total size of the graph + # Calculate total size of the fx module total_size_of_graph = 0 for node in nodes: if node.op == 'output': break total_size_of_graph += node.size_bytes.total_size + # Find the device with the max mem size device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + # AOT based partition if partitioner_config.mode == PartitionMode.aot_based: self.aot_based_partition( partitioner_config.node_to_partition_mapping, partitioner_config.partition_to_logical_device_mapping ) + # Single partition if the whole module can be fit into one device elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: self.find_single_partition(total_size_of_graph) elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]): raise RuntimeError('Devices have no enough memory for the module') else: + # Sparse nn based partition if partitioner_config.mode == PartitionMode.sparse_nn: available_mem_bytes = self.devices[0].available_mem_bytes if not all(device.available_mem_bytes == available_mem_bytes for device in self.devices): @@ -283,11 +287,13 @@ def partition_graph( # sparse_nn_partition only support same memory size # TODO: add different size support for sparse_nn_partition self.sparse_nn_partition(available_mem_bytes) + # Cost aware partition elif partitioner_config.mode == PartitionMode.cost_aware: self.cost_aware_partition( partitioner_config.transfer_rate_bytes_per_sec, partitioner_config.node_to_latency_mapping ) + # KL based partition elif partitioner_config.mode == PartitionMode.kl_based: self.kl_based_partition( partitioner_config.transfer_rate_bytes_per_sec, @@ -303,7 +309,8 @@ def partition_graph( return ret def find_single_partition(self, total_size_of_graph) -> None: - """Only one partition (one graph on one device).""" + """Fit the whole fx module into one device + """ partition_0 = self.create_partition() for node in self.graph_module.graph.nodes: if node.op == 'output': @@ -316,18 +323,18 @@ def find_single_partition(self, total_size_of_graph) -> None: return def size_based_partition(self) -> None: - """This method is to partition the graph based on memory size. + """This method is to partition the fx module based on memory size. It uses greedy approach. The result may not be the best. The basic idea is: Step 1: - Find a device which has enough memory to fit the first node, create a empty partition + Find a device which has enough memory to fit the current node, create a empty partition with the size of that device. Then keep adding the following nodes into the partition until the partition is full. Step 2: Repeat Step 1 until no device left Step 3: If some nodes are left, create a partition for each left node (single node partition). - and then try to map those partitions into logical devices with non single node partitions. + and then try to map those partitions into logical devices with enough mem left. """ def find_device_based_on_size(node) -> Device: """Given a node, this function is to find a logical device @@ -365,16 +372,18 @@ def find_device_based_on_size(node) -> Device: partition.logical_device_ids.append(device.logical_id) else: # The current partition is not the first partition - # Check if the current node can fit into this partition + # Check if the current node can fit into current partition if partition_to_left_mem_bytes[partition] < total_size_of_input_nodes: # Check if no device is left if len(self.partitions) == len(self.devices): - # No device left, all the partitions before are non single node partitions + # No device is left + # Put the previous partitions into a list (non_single_node_partitions) non_single_node_partitions = self.partitions[:] # Create the first single node partition for the current node self.create_single_node_partition(node) continue # Some devices are still left + # Create a new partition with a mem size that is enough for the current node device = find_device_based_on_size(node) partition = self.create_partition() total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) @@ -382,7 +391,7 @@ def find_device_based_on_size(node) -> Device: partition.logical_device_ids.append(device.logical_id) partition.add_node(node) partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes - # No device left, create single node partitions + # Create single node partitions if no device is left else: self.create_single_node_partition(node) reorganize_partitions(self.partitions) @@ -395,7 +404,7 @@ def find_device_based_on_size(node) -> Device: return def do_partition(self) -> GraphModule: - """Return a module with submodules (partitions).""" + """Return a new fx module with submodule nodes (partitions).""" module_with_submodules = split_module( self.graph_module, self.torch_module, @@ -404,6 +413,7 @@ def do_partition(self) -> GraphModule: return module_with_submodules def dump_dag(self, module_with_submodules: GraphModule) -> DAG: + """Return the dag structure and the new fx module with submodules""" dag = DAG() for node in module_with_submodules.graph.nodes: if node.op == 'output': @@ -437,19 +447,21 @@ def create_partition(self) -> Partition: return partition def create_single_node_partition(self, node): - """Create a partition for a single node - """ + """Create a partition for a single node""" partition = self.create_partition() partition.add_node(node) return def sparse_nn_partition(self, available_mem_bytes: int) -> None: """This method partition a sparse nn module. - It first traverse all the nodes and do the partitions based on memory size. + It is size based partition but different from size_based_partition, + it only works when all the devices have same memory size (available_mem_bytes). + In the future, devices with different mem sizes will be supported like size_based_partition. + It first traverse all the nodes and do the partitions based on the same memory size. If the current partition has no enough memory left for a new op node (call_module, call_method, call_function), a new partition is created. - Different from size_based_partition, when traversing cross the boundary between - non-embedding nodes and embedding nodes, a new partition is created regardlessly. + When crossing the boundary between non-embedding nodes and embedding nodes, + a new partition is created regardlessly. For example, if the current node is a non-embedding node but the next node is an embedding node, a new partition is created for the next node. After the partition, the partitions are combined as much as possible. @@ -470,7 +482,7 @@ def combine_partitions_based_on_size(partitions: List[Partition], available_mem_ We go from the largest and selection partition_0. Check the bfs level for two partitions, if the level difference is less than 2, it can be combined. - Then repeat step 1. + step 2: repeat step 1 until no partitions can be combined """ find_combination = True while find_combination: @@ -518,6 +530,9 @@ def find_partition_to_combine_based_on_size( return find_combination, partitions def reset_partition_in_sparse_nn(partition, new_partition=True): + """If crossing the boudary between non-embedding nodes and + embedding nodes, create a new partition + """ if in_embedding_region: embedding_partitions.append(partition) else: @@ -604,9 +619,9 @@ def cost_aware_partition( node_to_latency_mapping: Dict[Node, NodeLatency] ) -> None: """This method is to partition the fx module based on the cost. - The cost is the total latency of running the whole graph. + The cost is the total latency of running the whole fx module. In partitioner_utils.py, the cost model is built. - The algorithm is: + The cost aware partition algorithm is: #1. At every begining, each node is a partition. Then we map all the partitions to the devices and calculate the cost @@ -623,7 +638,7 @@ def try_combining_partitions( p1_index, partitions ) -> float: - """Given two partitions and a list of partitions, try to combine these two partitions + """Given two partitions and a list of partitions, combine these two partitions and see what is the cost of the modified partition list """ p0 = partitions[p0_index] @@ -656,10 +671,10 @@ def search_combination( find two partitions to combine so the cost of the partitions can be reduced. The algorithm is : - 1. Going through all the partition pairs and see - if the pair of partitions can be combined. - 2. If they are combined, the cost is calculated. - 3. Select the minimum cost and combine its cooresponding partition pair + 1. Go through all the partition pairs and see + if any pair of partitions can be combined. + 2. Calculate the cost after the combination. + 3. Select the minimum cost and combine its cooresponding partition pair. """ partition_to_latency_mapping = get_partition_to_latency_mapping(self.partitions, node_to_latency_mapping) cost = get_latency_of_partitioned_graph(self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec) @@ -704,7 +719,7 @@ def search_combination( transfer_rate_bytes_per_sec, node_to_latency_mapping ) - # Make sure all partitions are set up correctly. + # Make sure all partitions are set up correctly reorganize_partitions(self.partitions) # Set up node to partition mapping self.node_to_partition = get_node_to_partition_mapping(self.partitions) @@ -725,7 +740,7 @@ def kl_based_partition( Using size_based_partition, n0 and n1 are in Partition p0. n2, n3 and n4 in Partition p1. The current cost is esimated. We first tried using n0 to swap with n2 from the other partiton. - Then we found swapping n0 and n2 shows a lower cost + Then we see that swapping n0 and n2 shows a lower cost than the current cost and it is the minimum among other pairs like (n0, None)(This means moving n0 to Partition without swapping other nodes), (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost @@ -828,7 +843,8 @@ def swap_node_to_partition(node, p0, p1, node_to_latency_mapping, transfer_rate_ node_to_latency_mapping, transfer_rate_bytes_per_sec ) - # Update cost and node pair + # Update the cost + # Track the swapped node pair and their partitions if new_cost < cost: cost = new_cost node_pair = new_node_pair diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py new file mode 100644 index 000000000000..4b5ea52c5c1f --- /dev/null +++ b/torch/fx/experimental/const_fold.py @@ -0,0 +1,269 @@ +import operator +from typing import Dict, Set, List, Optional + +import torch.fx +from torch.fx.experimental.subgraph_creation_example import split_module +import re + + +def _make_tuple(x): + """ + Helper to convert x into a one item tuple if it's not a tuple already. + """ + return x if isinstance(x, tuple) else (x,) + + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + const_output_names: Optional[List[str]] = None, + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.const_output_names = const_output_names + self.has_folding_been_run = False + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if self.const_subgraph_module is None or self.const_output_names is None: + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. We _make_tuple here because + # single attr const fold subgraphs output a single Tensor while + # multiple outputs are returned as Tuple[Tensor,]. + folded_attrs = _make_tuple(self.const_subgraph_module()) + + # Look for output node from const folding subgraph and set attrs on the + # module with the results. + for i in range(len(folded_attrs)): + setattr( + self, self.const_output_names[i], torch.nn.Parameter(folded_attrs[i]) + ) + + +def split_const_subgraphs( + module: torch.nn.Module, +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + mod_traced = torch.fx.symbolic_trace(module) + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: Set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op == "get_attr" or set(node.all_input_nodes).issubset(const_nodes): + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + # Gather all names that are output from the const folding subgraph, which we + # will need to set dummy params on the module. + const_output_names: List[str] = [] + for node in split.submod_0.graph.nodes: + if node.op == "output": + # Note: we _make_tuple here because the output Node either contains + # a single output Node, or Tuple[Node], so this simplifies things. + const_output_names = [o.name for o in _make_tuple(node.args[0])] + break + + # Make sure the attr name we want to use is uniquely named in the module. + for i in range(len(const_output_names)): + # Add a suffix to make it easier to tell these were the result of const folding. + name = const_output_names[i] + "__CF" + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + const_output_names[i] = name + + # Now track the const_output_names to what name is used in the parent graph + # from the split via call_function getitem, to see what order it is passed + # into the non-const subgraph submod_1. First look to the parent module + # containing/calling into the const/non-const submodules to determine what + # the inputs are to each. Note if submod_0 had a single output then there is + # no getitem, and we can simply use the output from the call to submoid_0. + call_submod_0_args, call_submod_1_args = None, None + orig_ph_targets: List[str] = [] + for node in split.graph.nodes: + if node.op == "placeholder": + orig_ph_targets.append(node.target) + + if node.op == "call_module": + if node.target == "submod_0": + call_submod_0_args = node.args + continue + elif node.target == "submod_1": + call_submod_1_args = node.args + continue + assert call_submod_0_args is not None and call_submod_1_args is not None + + # Look through the args for the call into submod_1, and find the args that + # come from submod_0. Also look for get_attrs fed directly from the parent + # split into submod_1, i.e. those attrs that are not constant folded. + submod_1_input_idx_to_folded_attr_name: Dict[int, str] = {} + submod_1_input_idx_to_unfolded_attr_name: Dict[int, str] = {} + for i, node in enumerate(call_submod_1_args): + const_output_name = None + # If we only had a single output from submod_0 then we simply look for + # the call_module into it. + if len(const_output_names) == 1: + if node.op == "call_module" and node.target == "submod_0": + const_output_name = const_output_names[0] + + # Else we had multiple outputs from submod_0, so we need to look for all + # getitems from the call to it. + else: + if ( + node.op == "call_function" + and node.target == operator.__getitem__ + and node.args[0].target == "submod_0" + ): + const_output_name = const_output_names[node.args[1]] + + # Now map from the index of the constant into calling submod_1 and map + # to the constant output name, which we use for swapping in getattrs + # instead of placeholders in submod_1. + if const_output_name is not None: + submod_1_input_idx_to_folded_attr_name[i] = const_output_name + elif node.op == "get_attr": + submod_1_input_idx_to_unfolded_attr_name[i] = node.target + + assert len(submod_1_input_idx_to_folded_attr_name) == len(const_output_names) + + # Now we have a mapping from const output names to the index they are passed + # into submod_1, so swap in getattrs for placeholders. + ph_idx = 0 + for node in split.submod_1.graph.nodes: + if node.op != "placeholder": + continue + is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name.keys() + is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name.keys() + if not is_folded_attr and not is_unfolded_attr: + ph_idx += 1 + continue + + const_output_name = ( + submod_1_input_idx_to_folded_attr_name[ph_idx] + if is_folded_attr + else submod_1_input_idx_to_unfolded_attr_name[ph_idx] + ) + if is_folded_attr: + assert not hasattr(mod_traced, const_output_name) + # Use a dummy param, which will be overwritten when we run const folding. + setattr( + mod_traced, + const_output_name, + torch.nn.Parameter(torch.randn(1)), + ) + with split.submod_1.graph.inserting_before(node): + node.replace_all_uses_with(split.submod_1.graph.get_attr(const_output_name)) + split.submod_1.graph.erase_node(node) + ph_idx += 1 + + # We may need to reorder placeholders to ensure they have the same order as + # they do in the original split. + ph_idx = 0 + node = next(iter(split.submod_1.graph.nodes)) + while node.op != "root": + if node.op != "placeholder": + node = node.next + continue + + curr_orig_ph_target = orig_ph_targets[ph_idx] + ph_idx += 1 + # If this ph is in the correct position, nothing to do. + if curr_orig_ph_target == node.target: + node = node.next + continue + + # This ph is not in the correct order, so search the rest of the graph + # for the ph we expected and prepend it before the current ph. + later_node = node.next + while later_node.op != "root": + if ( + later_node.op == "placeholder" + and curr_orig_ph_target == later_node.target + ): + break + later_node = later_node.next + assert later_node.op != "root" + node.prepend(later_node) + # Note we do not increment node here, as it still may be in the wrong + # place (we just prepended the ph that should have come before it). + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside submod_0, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to submod_0 must be constants accessible via get_attr. + for node in split.submod_0.graph.nodes: + if node.op != "placeholder": + continue + in_node = next(n for n in call_submod_0_args if n.name == node.target) + assert in_node.op == "get_attr" + with split.submod_0.graph.inserting_before(node): + node.replace_all_uses_with(split.submod_0.graph.get_attr(in_node.target)) + split.submod_0.graph.erase_node(node) + + return FoldedGraphModule( + mod_traced, split.submod_1.graph, split.submod_0.graph, const_output_names + ) diff --git a/torch/fx/experimental/graph_manipulation.py b/torch/fx/experimental/graph_manipulation.py index 2eea162faedb..4e6c23cbad9f 100644 --- a/torch/fx/experimental/graph_manipulation.py +++ b/torch/fx/experimental/graph_manipulation.py @@ -3,6 +3,7 @@ import torch from torch.fx.experimental.shape_prop import ShapeProp +from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes from torch.fx.graph import Graph, get_qualified_name from torch.fx.graph_module import GraphModule from torch.fx.node import Node, Target, map_arg @@ -122,19 +123,17 @@ def serialize_weight(tensor: torch.Tensor) -> Dict: def serialize_leaf_module( - mod: torch.nn.Module, weights_metadata: Dict, weights: Dict, name_prefix: str + node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str ) -> Dict: parameters: Dict[str, Any] = {} - parameters["name"] = type(mod).__name__ - for name, buffer in mod.named_buffers(): - weights_metadata[f"{name_prefix}.{name}"] = serialize_weight(buffer) - weights[f"{name_prefix}.{name}"] = buffer - for name, parameter in mod.named_parameters(): - weights_metadata[f"{name_prefix}.{name}"] = serialize_weight(parameter) - weights[f"{name_prefix}.{name}"] = parameter - if isinstance(mod.__constants__, List): - for constant in mod.__constants__: - parameters[constant] = str(getattr(mod, constant)) + + for p_name, p_value in node.attrs_for_lowering.items(): # type: ignore + if isinstance(p_value, torch.Tensor): + weights_metadata[f"{name_prefix}.{p_name}"] = serialize_weight(p_value) + weights[f"{name_prefix}.{p_name}"] = p_value + else: + parameters[p_name] = str(p_value) + return parameters @@ -187,6 +186,7 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D weight = serialize_weight(p) serialized_dict["weights"][prefix + name] = weight weights[prefix + name] = p + lift_lowering_attrs_to_nodes(fx_module) for node in fx_module.graph.nodes: node_rep: Dict[str, Any] = {} # Get shape/type info, currently not needed for call_module. @@ -217,7 +217,7 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D serialized_dict["modules"][node.target] = serialized_module else: node_rep["parameters"] = serialize_leaf_module( - submodules[node.target], + node, serialized_dict["weights"], weights, prefix + node.target, diff --git a/torch/fx/experimental/param_fetch.py b/torch/fx/experimental/param_fetch.py new file mode 100644 index 000000000000..6bce29b97e78 --- /dev/null +++ b/torch/fx/experimental/param_fetch.py @@ -0,0 +1,60 @@ +from torch.fx.graph_module import GraphModule +from typing import Any, Callable, Dict, List, Tuple, Type +import torch +import torch.nn as nn + + +# Matching method matches the attribute name of current version to the attribute name of `target_version` +def default_matching(name: str, target_version: int) -> str: + """Default matching method + """ + return name + +# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. +# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. +# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. +module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { + torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), + torch.nn.modules.conv.Conv2d: ( + 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching + ), + torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), + torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), + torch.nn.modules.pooling.MaxPool2d: ( + 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching + ), + torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), +} + +def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: + """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` + after checking module's version is compatible with the `module_fetch_book`. + """ + attrs_for_lowering: Dict[str, Any] = {} + attrs_for_lowering["name"] = torch.typename(mod) + + if type(mod) in module_fetch_book: + version, param_to_fetch, matching_method = module_fetch_book[type(mod)] + if version < mod._version: + raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + for attr in param_to_fetch: + attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) + else: + raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + return attrs_for_lowering + +def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. + """ + submodules = dict(fx_module.named_modules()) + + for node in fx_module.graph.nodes: + if node.op == "call_module": + if isinstance(submodules[node.target], GraphModule): + lift_lowering_attrs_to_nodes(submodules[node.target]) + else: + node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 072aef6e3b93..e6fc19a1394e 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -148,26 +148,7 @@ def forward(self, x): %topk_1 : [#users=1] = call_function[target=](args = (%sum_1, 3), kwargs = {}) # noqa: B950 return topk_1 - The Node semantics are as follows: - - - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. - ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument - denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to - the function parameters (e.g. ``x``) in the graph printout. - - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the - fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. - ``args`` and ``kwargs`` are don't-care - - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign - to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, - following the Python calling convention - - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is - as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. - ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. - - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method - to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, - *including the self argument* - - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement - in the Graph printout. + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. """ def __init__(self): """ @@ -568,7 +549,7 @@ def illegal_shadowing_name(name : str) -> bool: _shadows_builtin_name(name) while candidate in self._used_names or illegal_shadowing_name(candidate): - match = re.match(r"(.*)_(\d+)", candidate) + match = re.match(r"(.*)_(\d+)$", candidate) if match is None: candidate = candidate + '_1' else: @@ -636,10 +617,15 @@ def delete_unused_values(user : Node): not used in the remainder of the code are freed and the memory usage of the code is optimal. """ + if user.op == 'output': + body.append('\n') + return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): to_delete_str = ' = '.join([n.name for n in nodes_to_delete] + ['None']) - body.append(f'{to_delete_str}\n') + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') def emit_node(node : Node): if node.op == 'placeholder': @@ -649,20 +635,20 @@ def emit_node(node : Node): free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') raw_name = node.target.replace('*', '') if raw_name != node.name: - body.append(f'{node.name} = {raw_name}\n') + body.append(f'{node.name} = {raw_name}') return elif node.op == 'call_method': assert isinstance(node.target, str) body.append( f'{node.name} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})\n') + f'({_format_args(node.args[1:], node.kwargs)})') return elif node.op == 'call_function': assert callable(node.target) # pretty print operators if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n') + body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') return qualified_name = get_qualified_name(node.target) register_modules_used(qualified_name) @@ -671,26 +657,28 @@ def emit_node(node : Node): isinstance(node.args[1], str) and \ node.args[1].isidentifier(): # pretty print attribute access - body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}\n') + body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}') return - body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})\n') + body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})') return elif node.op == 'call_module': assert isinstance(node.target, str) - body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})\n') + body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') return elif node.op == 'get_attr': assert isinstance(node.target, str) - body.append(f'{node.name} = {_format_target(root_module, node.target)}\n') + body.append(f'{node.name} = {_format_target(root_module, node.target)}') return elif node.op == 'output': if node.type is not None: maybe_return_annotation = f" -> {type_repr(node.type)}" - body.append(f'return {repr(node.args[0])}\n') + body.append(f'return {repr(node.args[0])}') return raise NotImplementedError(f'node: {node.op} {node.target}') for node in self.nodes: + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one emit_node(node) delete_unused_values(node) diff --git a/torch/fx/node.py b/torch/fx/node.py index 1cc94be83e7e..d304a4c0a472 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -21,8 +21,34 @@ ]] class Node: - def __init__(self, graph: 'Graph', name: str, op: str, target: Target, - args: Tuple[Argument, ...], kwargs: Dict[str, Argument], + """ + ``Node`` is the data structure that represents individual operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + """ + def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', + args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], type : Optional[Any] = None) -> None: self.graph = graph self.name = name # unique name of value being created @@ -60,23 +86,33 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target, @property def next(self) -> 'Node': """ - Get the next node in the linked list + Returns the next ``Node`` in the linked list of Nodes. + + Returns: + + The next ``Node`` in the linked list of Nodes. """ return self._next @property def prev(self) -> 'Node': """ - Get the previous node in the linked list + Returns the previous ``Node`` in the linked list of Nodes. + + Returns: + + The previous ``Node`` in the linked list of Nodes. """ return self._prev - def prepend(self, x: 'Node'): - """Insert x before this node in the list of nodes in the graph. - Before: p -> self - bx -> x -> ax - After: p -> x -> self - bx -> ax + def prepend(self, x: 'Node') -> None: + """ + Insert x before this node in the list of nodes in the graph. Example:: + + Before: p -> self + bx -> x -> ax + After: p -> x -> self + bx -> ax Args: x (Node): The node to put before this node. Must be a member of the same graph. @@ -87,8 +123,9 @@ def prepend(self, x: 'Node'): p._next, x._prev = x, p x._next, self._prev = self, x - def append(self, x: 'Node'): - """Insert x after this node in the list of nodes in the graph. + def append(self, x: 'Node') -> None: + """ + Insert x after this node in the list of nodes in the graph. Equvalent to ``self.next.prepend(x)`` Args: @@ -103,9 +140,12 @@ def _remove_from_list(self): @property def args(self) -> Tuple[Argument, ...]: """ - Return the tuple of arguments to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more + The tuple of arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. """ return self._args @@ -121,9 +161,12 @@ def args(self, a : Tuple[Argument, ...]): @property def kwargs(self) -> Dict[str, Argument]: """ - Return the dict of kwargs to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more + The dict of keyword arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. """ return self._kwargs @@ -141,7 +184,12 @@ def all_input_nodes(self) -> List['Node']: """ Return all Nodes that are inputs to this Node. This is equivalent to iterating over ``args`` and ``kwargs`` and only collecting the values that - are Nodes + are Nodes. + + Returns: + + List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this + ``Node``, in that order. """ all_nodes : List['Node'] = [] map_arg(self.args, lambda n: all_nodes.append(n)) @@ -149,6 +197,9 @@ def all_input_nodes(self) -> List['Node']: return all_nodes def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): + """ + This API is internal. Do *not* call it directly. + """ self._args = new_args self._kwargs = new_kwargs @@ -168,7 +219,14 @@ def __repr__(self) -> str: def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']: """ Replace all uses of ``self`` in the Graph with the Node ``replace_with``. - Returns the list of nodes on which this change was made. + + Args: + + replace_with (Node): The node to replace all uses of ``self`` with. + + Returns: + + The list of Nodes on which this change was made. """ to_process = list(self.users) for use_node in to_process: @@ -190,9 +248,12 @@ def maybe_replace_node(n : Node) -> Node: def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ - if isinstance(a, tuple): + if isinstance(a, tuple) and hasattr(a, '_fields'): + elements = tuple(map_arg(elem, fn) for elem in a) + return type(a)(*elements) # type: ignore + elif isinstance(a, tuple): return tuple(map_arg(elem, fn) for elem in a) - if isinstance(a, list): + elif isinstance(a, list): return immutable_list(map_arg(elem, fn) for elem in a) elif isinstance(a, dict): return immutable_dict((k, map_arg(v, fn)) for k, v in a.items()) diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index f8c4aa8d8366..ce406aa787ee 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -50,7 +50,13 @@ def create_arg(self, a: Any) -> Argument: Can be override to support more trace-specific types. """ # aggregates - if isinstance(a, (tuple, list)): + if isinstance(a, tuple) and hasattr(a, '_fields'): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = tuple(self.create_arg(elem) for elem in a) + return type(a)(*args) # type: ignore + elif isinstance(a, (tuple, list)): return type(a)(self.create_arg(elem) for elem in a) elif isinstance(a, dict): r = {} diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index 6bdc8dd1070b..69e3c708dde3 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -1,6 +1,6 @@ import inspect from types import CodeType, FunctionType -from typing import Any, Dict, Optional, List, Callable, Union +from typing import Any, Dict, Optional, Tuple, List, Callable, Union import torch from torch._C import ScriptObject # type: ignore @@ -51,21 +51,31 @@ class Tracer(TracerBase): def __init__(self): super().__init__() - def create_arg(self, a: Any) -> Argument: + def create_arg(self, a: Any) -> 'Argument': """ A method to specify the behavior of tracing when preparing values to be used as arguments to nodes in the ``Graph``. By default, the behavior includes: - - Iterate through collection types (e.g. tuple, list, dict) and recursively - call ``create_args`` on the elements. - - Given a Proxy object, return a reference to the underlying IR ``Node`` - - Given a non-Proxy Tensor object, emit IR for various cases: - - For a Parameter, emit a ``get_attr`` node referring to that Parameter - - For a non-Parameter Tensor, store the Tensor away in a special - attribute referring to that attribute. + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` """ # The base tracer is used to construct Graphs when there is no associated # module hierarchy, so it can never create parameter references. @@ -115,28 +125,32 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo their constituent ops are recorded, unless specified otherwise via this parameter. - Args - m - The module itself - module_qualified_name - The path to root of this module. For example, - if you have a module hierarchy where submodule ``foo`` contains - submodule ``bar``, which contains submodule ``baz``, that module will - appear with the qualified name ``foo.bar.baz`` here. + Args: + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. """ return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) - def path_of_module(self, mod) -> str: + def path_of_module(self, mod : torch.nn.Module) -> str: """ Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if ``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function will return the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. """ for n, p in self.root.named_modules(): if mod is p: return n raise NameError('module is not installed as a submodule') - def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs): + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: """ Method that specifies the behavior of this ``Tracer`` when it encounters a call to an ``nn.Module`` instance. @@ -149,6 +163,20 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwa This method can be overridden to--for example--create nested traced GraphModules, or any other behavior you would want while tracing across ``Module`` boundaries. + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. """ module_qualified_name = self.path_of_module(m) if not self.is_leaf_module(m, module_qualified_name): @@ -205,6 +233,16 @@ def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: """ Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` can either be an ``nn.Module`` instance or a Python callable. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. """ if isinstance(root, torch.nn.Module): self.root = root diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 78c226ab1739..57a66dde3a4c 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -417,12 +417,9 @@ def build_AnnAssign(ctx, stmt): @staticmethod def build_Delete(ctx, stmt): - if len(stmt.targets) > 1: - source_range = ctx.make_range(stmt.lineno, stmt.col_offset, - stmt.col_offset + len("del")) - raise NotSupportedError( - source_range, 'del with more than one operand is not supported') - return Delete(build_expr(ctx, stmt.targets[0])) + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del")) + + return Delete(r, [build_expr(ctx, target) for target in stmt.targets]) @staticmethod def build_Return(ctx, stmt): diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 1aa161511b01..19085f155020 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -452,7 +452,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( ncclCommCounter_(0), terminateProcessGroup_(false), opTimeout_(options->opTimeout), - futureNCCLCallbackStreams_(c10::cuda::device_count()), isHighPriorityStream_(options->isHighPriorityStream) { TORCH_CHECK(at::cuda::getNumGPUs() != 0, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); @@ -462,7 +461,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( if (blockingWait_ && asyncErrorHandling_) { LOG(INFO) << "[Rank " << rank_ << "] NCCL_BLOCKING_WAIT and NCCL_ASYNC_ERROR_HANDLING " - << "should not both be enabled. " + << "should not both be enabled. " << "Only NCCL_BLOCKING_WAIT is being used in this process."; asyncErrorHandling_ = false; } @@ -867,15 +866,6 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // Creates the NCCL streams streamVal.push_back(at::cuda::getStreamFromPool(isHighPriorityStream_)); - - // If not set before, get a dedicated stream for the device to run - // FutureNCCL then callbacks. - std::lock_guard lock(mutex_); - if (futureNCCLCallbackStreams_[deviceIndex] == nullptr) { - futureNCCLCallbackStreams_[deviceIndex] = - std::make_shared( - at::cuda::getStreamFromPool(isHighPriorityStream_)); - } } // [Note 2 ] @@ -1018,17 +1008,7 @@ std::vector ProcessGroupNCCL::WorkNCCL::result() { c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: getFuture() { - TORCH_INTERNAL_ASSERT( - outputs_->size() == 1, - "WorkNCCL's getFuture API is only supported for single-process single-device mode."); - auto deviceIndex = (*outputs_)[0].device().index(); - // Create a new FutureNCCL object after checking for single-process - // single-device mode. - return c10::make_intrusive( - at::IValue(*outputs_), - deviceIndex, - cudaEvents_, - futureNCCLCallbackStreams_[deviceIndex]); + return future_; } void ProcessGroupNCCL::workEnqueue( @@ -1066,21 +1046,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( bool can_profile = outputs.size() == 1; auto work = initWork(devices, rank_, opType, can_profile ? profilingTitle : nullptr); - // Store references to outputs and futureNCCLCallbackStream to be used by - // WorkNCCL::getFuture. + // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_; - - if (work->recordFunctionEndCallback_) { - // recordFunctionEndCallback_ is normally called in fininsh() function by - // base class, but since finish is not called by WorkNCCL, we schedule this - // function to be run when work is done. - // Note when can_profile is false, profilingTitle is not provided and so, - // recordFunctionEndCallback_ is not set. - work->getFuture()->addCallback(std::move(work->recordFunctionEndCallback_)); - } - - at::cuda::OptionalCUDAGuard gpuGuard; @@ -1121,11 +1088,29 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( work->ncclComms_[i] = ncclComms[i]; } + { + at::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + // Set appropriate work parameters. work->blockingWait_ = blockingWait_; work->opTimeout_ = opTimeout_; work->store_ = store_; + if (work->recordFunctionEndCallback_) { + // recordFunctionEndCallback_ is normally called in fininsh() function by + // base class, but since finish is not called by WorkNCCL, we schedule this + // function to be run when work is done. Note that addCallback() onto the + // Work's CUDAFuture is not useful here, as it would just run the callback + // inline. + // Note when can_profile is false, profilingTitle is not provided and so, + // recordFunctionEndCallback_ is not set. + work->recordFunctionEndCallback_(); + } + if (asyncErrorHandling_) { workEnqueue(work); } @@ -1154,10 +1139,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( auto work = initWork(devices, rank_, opType); if (opType == OpType::RECV) { - // Store references to outputs and futureNCCLCallbackStream to be used by - // WorkNCCL::getFuture. + // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(tensors); - work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_; } at::cuda::OptionalCUDAGuard gpuGuard; @@ -1202,6 +1185,13 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( work->store_ = store_; } + if (opType == OpType::RECV) { + at::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + return work; } diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index fd57f105df0b..4d9dc3bd1ae8 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include #include @@ -108,7 +110,7 @@ class ProcessGroupNCCL : public ProcessGroup { bool finishedGPUExecution(); // Get a Future object that will be marked as completed internally. - // It actually returns a FutureNCCL object which is a sub class Future. + // It actually returns a CUDAFuture object which is a sub class of Future. c10::intrusive_ptr getFuture() override; // Helper function that sets an exception_ptr on the WorkNCCL object. @@ -168,12 +170,12 @@ class ProcessGroupNCCL : public ProcessGroup { // to the store. c10::intrusive_ptr store_; - // Store a reference to NCCL collective's outputs to be used by getFuture. + // Store a reference to NCCL collective's outputs, used by result and to + // give a more descriptive message when representing the Work as a string. std::shared_ptr> outputs_; - // Store streams that run FutureNCCL then callbacks. - std::vector> - futureNCCLCallbackStreams_; + // The future returned by getFuture. + c10::intrusive_ptr future_; friend class ProcessGroupNCCL; }; @@ -192,202 +194,6 @@ class ProcessGroupNCCL : public ProcessGroup { bool isHighPriorityStream; }; - // FutureNCCL is a subclass of ivalue's Future. The goal is to use - // this class in getFuture API of WorkNCCL. This Future is mostly a - // wrapper to synchronize streams appropriately and it mostly enables - // the async programming model of CUDA while trying to adhere to the - // Future interface. FutureNCCL does not support NCCL_BLOCKING_WAIT flag - // or NCCL's barrier(). - // - // If created by WorkNCCL's getFuture API, FutureNCCL has a reference to - // WorkNCCL's cudaEvents, NCCL collective's outputs, device index of - // outputs' device, and the ProcesGroupNCCL's dedicated - // futureNCCLCallbackStream for outputs' device that runs all the then - // callbacks called from this FutureNCCL. Its value is NCCL collective's - // outputs. FutureNCCL only supports single-process single-device mode where - // the size of outputs is equal to 1. - // - // If created by FutureNCCL's then callback, its value becomes the value of - // callback() and its cudaEvents will record the NCCL stream that runs that - // callback. Before invoking the callback, FutureNCCL will synchronize its - // own cudaEvents with the stream that runs the callback. This design - // enables synchronizing the appropriate streams and avoids stalling PyTorch's - // default stream while running the callback. In case of multiple then - // callbacks, the design will work like a chain such that FutureNCCL n will - // wait on the cudaEvents from FutureNCCL n - 1. All callbacks are executed on - // outputs' device's dedicated futureNCCLCallbackStream. - struct FutureNCCL : at::ivalue::Future { - public: - explicit FutureNCCL( - at::IValue value, - c10::DeviceIndex deviceIndex, - std::shared_ptr> cudaEvents, - std::shared_ptr futureNCCLCallbackStream) - : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())), - value_(std::move(value)), - deviceIndex_(deviceIndex), - cudaEvents_(cudaEvents), - futureNCCLCallbackStream_(futureNCCLCallbackStream) { - TORCH_INTERNAL_ASSERT( - cudaEvents_->size() == 1, - "FutureNCCL only supports single-process single-device mode."); - } - - // This constructor is used by then callback, it skips setting the value at - // the beginning. Later, the value will be set using markCompleted with the - // return value of callback. - explicit FutureNCCL( - c10::DeviceIndex deviceIndex, - std::shared_ptr> cudaEvents, - std::shared_ptr futureNCCLCallbackStream) - : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())), - deviceIndex_(deviceIndex), - cudaEvents_(cudaEvents), - futureNCCLCallbackStream_(futureNCCLCallbackStream) { - TORCH_INTERNAL_ASSERT( - cudaEvents_->size() == 1, - "FutureNCCL only supports single-process single-device mode."); - } - - // Gets the current stream of the device and synchronizes recorded streams - // with that. It will return after synchronizing the correct GPU streams to - // ensure we can have async CUDA execution and it does not wait for the - // entire operation to complete on GPU. - void wait() override { - if (error_) { - throw *error_; - } - auto stream = at::cuda::getCurrentCUDAStream(deviceIndex_); - (*cudaEvents_)[0].block(stream); - } - - // If FutureNCCL was created by FutureNCCL::then, its value would be empty - // initially. FutureNCCL::then will later use this method to set its value - // to the return value of the callback. - void markCompleted(at::IValue value) override { - TORCH_INTERNAL_ASSERT( - value_.isNone(), - "Attempting to set value of a FutureNCCL which has a value." - "FutureNCCL's value was internally set to NCCL collective's " - "outputs or the return value of the callback."); - value_ = std::move(value); - } - - // Just returns FutureNCCL's value after wait returns. - at::IValue value() override { - TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.") - wait(); - return value_; - } - - const at::IValue& constValue() override { - TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.") - wait(); - return value_; - } - - // Adds a callback to FutureNCCL. It invokes the callback inline after - // synchronizing FutureNCCL's own cudaEvents with the stream that runs - // this callback. This new FutureNCCL's cudaEvents will record the - // callback's stream and will have the result value of the callback. - void addCallback(std::function callback) override { - (*cudaEvents_)[0].block(*futureNCCLCallbackStream_); - c10::OptionalStreamGuard streamGuard{ - c10::Stream(*futureNCCLCallbackStream_)}; - callback(); - } - - // Adds a callback to FutureNCCL, and returns another FutureNCCL to hold - // the return value of the callback and new cudaEvents that recorded the - // stream that runs this callback. - c10::intrusive_ptr then( - std::function callback, - at::TypePtr /* unused */) override { - // Create a new cudaEvents object of size 1 that will record - // futureNCCLCallbackStream_ after callback and will be passed to the new - // FutureNCCL. - auto thenFutCudaEvents = - std::make_shared>(1); - // Create a FutureNCCL without setting a value. - auto fut = c10::make_intrusive( - deviceIndex_, thenFutCudaEvents, futureNCCLCallbackStream_); - - // Do not free the underlying data storage of value_ before its - // usage on futureNCCLCallbackStream_ finish. - if (record_stream_cb_ != nullptr) { - // If a Python communication hook is used, record_stream_cb_ will be - // set in torch/csrc/jit/python/pybind_utils.h, which allows Python - // dependency to be imported. - record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap()); - } else { - // If a C++ communication hook is used, create and set a record stream - // callback. - TORCH_INTERNAL_ASSERT( - value_.isTensorList() || value_.isTensor(), - "the future value must be either a tensor list or a tensor."); - at::Tensor tensor; - if (value_.isTensorList()) { - const auto tensors = value_.toTensorVector(); - TORCH_INTERNAL_ASSERT( - tensors.size() == 1, "expected exactly 1 tensor"); - tensor = tensors[0]; - } else { - tensor = value_.toTensor(); - } - c10::cuda::CUDACachingAllocator::recordStream( - tensor.storage().data_ptr(), *futureNCCLCallbackStream_); - } - - // Use the dedicated callback stream to run callback. - // Cannot move capture std::function in lambda, because it cannot deduce - // the template type for std::function. Hence use std::bind to explicitly - // specify types. - addCallback(std::bind( - [&](std::function cb) { - try { - fut->markCompleted(at::IValue(cb())); - // In case of chained then callback calls, thenFutCudaEvents - // records callback's stream. - (*thenFutCudaEvents)[0].record(*futureNCCLCallbackStream_); - } catch (const std::exception& e) { - fut->setError(std::current_exception()); - } - }, - std::move(callback))); - return fut; - } - - // Checks cudaEventQuery with cudaEvents. Returns true if a FutureError was - // recorded or the entire operation is completed on the GPU. - bool completed() const override { - if (error_) { - return true; - } - // Checking the work's corresponding CUDA events' status - auto ret = cudaEventQuery((*cudaEvents_)[0]); - return ret != cudaErrorNotReady || ret == cudaSuccess; - } - - bool hasValue() const override { - return !value_.isNone(); - } - - void setRecordStreamCallback( - std::function - record_stream_cb) override { - record_stream_cb_ = std::move(record_stream_cb); - } - - private: - at::IValue value_; - c10::DeviceIndex deviceIndex_; - std::shared_ptr> cudaEvents_; - std::shared_ptr futureNCCLCallbackStream_; - std::function - record_stream_cb_; - c10::optional error_; - }; - // If you wish to create multiple process groups, each with a potentially // different rank and size, you can do so by passing a new store instance // to each one. If you have only a single store object, you can @@ -723,16 +529,6 @@ class ProcessGroupNCCL : public ProcessGroup { // set contains the string representation of ncclUniqueId. std::unordered_set abortedComms_; - // In single-process single-device mode, WorkNCCL::getFuture is supported. - // Depending on the device index of collective outputs, WorkNCCL will pass - // the corresponding device's then callback stream to FutureNCCL. - // We just inititalize futureNCCLCallbackStreams_ inside the constructor and - // set its size to the total number of available devices and depending on the - // device of the NCCL collective's outputs, we later set the callback stream - // of the corresponding device inside ProcessGroupNCCL::getNCCLComm if not set - // before. - std::vector> futureNCCLCallbackStreams_; - // Schedule NCCL operations on high priority CUDA streams. bool isHighPriorityStream_ = false; diff --git a/torch/lib/c10d/test/CUDATest.cu b/torch/lib/c10d/test/CUDATest.cu index c47b29ea536d..88f87492206c 100644 --- a/torch/lib/c10d/test/CUDATest.cu +++ b/torch/lib/c10d/test/CUDATest.cu @@ -17,6 +17,7 @@ __global__ void waitClocks(const uint64_t count) { void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks) { waitClocks<<<1, 1, 0, stream.stream()>>>(clocks); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } int cudaNumDevices() { diff --git a/torch/library.h b/torch/library.h index 41178fd10e07..ac936d29c520 100644 --- a/torch/library.h +++ b/torch/library.h @@ -643,7 +643,7 @@ class TorchLibraryInit final { /// for any given namespace. #define TORCH_LIBRARY(ns, m) \ static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \ - static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \ + static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \ torch::Library::DEF, \ &TORCH_LIBRARY_init_ ## ns, \ #ns, c10::nullopt, __FILE__, __LINE__ \ @@ -669,7 +669,7 @@ class TorchLibraryInit final { /// that it can only be called once for a given namespace. #define _TORCH_LIBRARY_FRAGMENT(ns, m, uid) \ static void C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid) (torch::Library&); \ - static torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _, uid) ( \ + static const torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _, uid) ( \ torch::Library::FRAGMENT, \ &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid), \ #ns, c10::nullopt, __FILE__, __LINE__ \ @@ -725,7 +725,7 @@ class TorchLibraryInit final { /// and dispatch key in the same translation unit. #define _TORCH_LIBRARY_IMPL(ns, k, m, uid) \ static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid) (torch::Library&); \ - static torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ## _, uid) ( \ + static const torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ## _, uid) ( \ torch::Library::IMPL, \ c10::guts::if_constexpr( \ []() { return & C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid); }, \ diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 85b2e0754b05..c3b99dc7abf0 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -286,7 +286,10 @@ Args: input (Tensor): The input tensor. If dim is None, x must be 1-D or 2-D, unless :attr:`ord` is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D - will be returned. + will be returned. Its data type must be either a floating point or complex type. For complex + inputs, the norm is calculated on of the absolute values of each element. If the input is + complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will + be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat). ord (int, float, inf, -inf, 'fro', 'nuc', optional): The order of norm. inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object. @@ -409,7 +412,9 @@ times the matrix norm of the inverse of :attr:`input`. And for norms ``p = {None, 2, -2}`` this is defined as the ratio between the largest and smallest singular values. -This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. +This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. +If the input is complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will +be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat). .. note:: For ``p = {None, 2, -2}`` the condition number is computed as the ratio between the largest and smallest singular values computed using :func:`torch.linalg.svd`. diff --git a/torch/nn/intrinsic/quantized/modules/conv_relu.py b/torch/nn/intrinsic/quantized/modules/conv_relu.py index 76407062511f..8dd931ff05a8 100644 --- a/torch/nn/intrinsic/quantized/modules/conv_relu.py +++ b/torch/nn/intrinsic/quantized/modules/conv_relu.py @@ -16,7 +16,7 @@ class ConvReLU1d(nnq.Conv1d): Same as torch.nn.quantized.Conv1d """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d # type: ignore[assignment] def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, @@ -55,7 +55,7 @@ class ConvReLU2d(nnq.Conv2d): Same as torch.nn.quantized.Conv2d """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d # type: ignore[assignment] def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, @@ -94,7 +94,7 @@ class ConvReLU3d(nnq.Conv3d): Attributes: Same as torch.nn.quantized.Conv3d """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d # type: ignore[assignment] def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index d0cfa5f80512..4b07682b1af7 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -831,7 +831,7 @@ def extra_repr(self) -> str: class MultiheadAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces. - See reference: Attention Is All You Need + See `Attention Is All You Need `_ .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O @@ -849,7 +849,7 @@ class MultiheadAttention(Module): vdim: total number of features in value. Default: None. Note: if kdim and vdim are None, they will be set to embed_dim such that - query, key, and value have the same number of features. + query, key, and value have the same number of features. Examples:: diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index e76e307d36a6..48e58d637ea6 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -434,8 +434,14 @@ class SyncBatchNorm(_BatchNorm): >>> # With Learnable Parameters >>> m = nn.SyncBatchNorm(100) >>> # creating process group (optional) - >>> # process_ids is a list of int identifying rank ids. - >>> process_group = torch.distributed.new_group(process_ids) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> # Without Learnable Parameters >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) >>> input = torch.randn(20, 100, 35, 45, 10) @@ -564,8 +570,14 @@ def convert_sync_batchnorm(cls, module, process_group=None): >>> torch.nn.BatchNorm1d(100), >>> ).cuda() >>> # creating process group (optional) - >>> # process_ids is a list of int identifying rank ids. - >>> process_group = torch.distributed.new_group(process_ids) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group) """ diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 33f2a84aed74..f22c35fa39ff 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -16,7 +16,7 @@ from typing import Optional, List, Tuple convolution_notes = \ - {"groups_note": """* :attr:`groups` controls the connections between inputs and outputs. + {"groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. :attr:`in_channels` and :attr:`out_channels` must both be divisible by :attr:`groups`. For example, @@ -27,14 +27,14 @@ concatenated. * At groups= :attr:`in_channels`, each input channel is convolved with its own set of filters (of size - :math:`\\frac{\\text{out\_channels}}{\\text{in\_channels}}`).""", # noqa: W605 + :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""", - "depthwise_separable_note": """When `groups == in_channels` and `out_channels == K * in_channels`, + "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`, where `K` is a positive integer, this operation is also known as a "depthwise convolution". In other words, for an input of size :math:`(N, C_{in}, L_{in})`, a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments - :math:`(C_\\text{in}=C_\\text{in}, C_\\text{out}=C_\\text{in} \\times \\text{K}, ..., \\text{groups}=C_\\text{in})`."""} # noqa: W605,B950 + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`."""} # noqa: B950 diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 3c8c626047dc..8256b111b988 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -15,12 +15,15 @@ class PixelShuffle(Module): `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ by Shi et. al (2016) for more details. + Note that this function can take inputs with any number of batch dimensions: + :math:`(L, H_{in}, W_{in})`, :math:`(N, L, H_{in}, W_{in})`, :math:`(N_1, N_2, L, H_{in}, W_{in})`, etc. + Args: upscale_factor (int): factor to increase spatial resolution by Shape: - - Input: :math:`(N, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2` - - Output: :math:`(N, C, H_{out}, W_{out})` where + - Input: :math:`(*, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2` + - Output: :math:`(*, C, H_{out}, W_{out})` where :math:`H_{out} = H_{in} \times \text{upscale\_factor}` and :math:`W_{out} = W_{in} \times \text{upscale\_factor}` diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 331d3885bd30..dacd74a2fba0 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -2,10 +2,9 @@ import torch from torch.cuda import nccl from torch._utils import _take_tensors, _flatten_dense_tensors, \ - _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index + _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index, _handle_complex from typing import List - def broadcast(tensor, devices=None, *, out=None): r"""Broadcasts a tensor to specified GPU devices. @@ -27,6 +26,7 @@ def broadcast(tensor, devices=None, *, out=None): a tuple containing :attr:`out` tensors, each containing a copy of :attr:`tensor`. """ + tensor = _handle_complex(tensor) if not ((devices is None) ^ (out is None)): raise RuntimeError( "Exactly one of 'devices' and 'out' must be specified, but got " @@ -54,6 +54,7 @@ def broadcast_coalesced(tensors, devices, buffer_size=10485760): A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`. """ devices = [_get_device_index(d) for d in devices] + tensors = [_handle_complex(t) for t in tensors] return torch._C._broadcast_coalesced(tensors, devices, buffer_size) @@ -182,6 +183,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out= a tuple containing :attr:`out` tensors, each containing a chunk of :attr:`tensor`. """ + tensor = _handle_complex(tensor) if out is None: devices = [_get_device_index(d) for d in devices] return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) @@ -196,6 +198,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out= "but got chunk_sizes={}".format(chunk_sizes)) return tuple(torch._C._scatter_out(tensor, out, dim, streams)) + def gather(tensors, dim=0, destination=None, *, out=None): r"""Gathers tensors from multiple GPU devices. @@ -222,6 +225,7 @@ def gather(tensors, dim=0, destination=None, *, out=None): the :attr:`out` tensor, now containing results of concatenating :attr:`tensors` along :attr:`dim`. """ + tensors = [_handle_complex(t) for t in tensors] if out is None: if destination == -1: warnings.warn( diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 67cb0e1f5dc2..d2d7e5591fb7 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -42,15 +42,61 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM Arguments: model (torch.nn.Module): the model to be exported. - args (tuple of arguments or torch.Tensor): the inputs to - the model, e.g., such that ``model(*args)`` is a valid - invocation of the model. Any non-Tensor arguments (including None) will - be hard-coded into the exported model; any Tensor arguments - will become inputs of the exported model, in the order they - occur in args. If args is a Tensor, this is equivalent - to having called it with a 1-ary tuple of that Tensor. - (Note: passing keyword arguments to the model is not currently - supported. Give us a shout if you need it.) + args (tuple of arguments or torch.Tensor, a dictionary consisting of named arguments (optional)): + a dictionary to specify the input to the corresponding named parameter: + - KEY: str, named parameter + - VALUE: corresponding input + args can be structured either as: + + 1. ONLY A TUPLE OF ARGUMENTS or torch.Tensor:: + + ‘’args = (x, y, z)’' + + The inputs to the model, e.g., such that ``model(*args)`` is a valid invocation + of the model. Any non-Tensor arguments will be hard-coded into the exported model; + any Tensor arguments will become inputs of the exported model, in the order they + occur in args. If args is a Tensor, this is equivalent to having + called it with a 1-ary tuple of that Tensor. + + 2. A TUPLE OF ARGUEMENTS WITH A DICTIONARY OF NAMED PARAMETERS:: + + ‘’args = (x, + { + ‘y’: input_y, + ‘z’: input_z + }) ‘’ + + The inputs to the model are structured as a tuple consisting of + non-keyword arguments and the last value of this tuple being a dictionary + consisting of named parameters and the corresponding inputs as key-value pairs. + If certain named argument is not present in the dictionary, it is assigned + the default value, or None if default value is not provided. + + Cases in which an dictionary input is the last input of the args tuple + would cause a conflict when a dictionary of named parameters is used. + The model below provides such an example. + + class Model(torch.nn.Module): + def forward(self, k, x): + ... + return x + + m = Model() + k = torch.randn(2, 3)   + x = {torch.tensor(1.): torch.randn(2, 3)} + + In the previous iteration, the call to export API would look like + + torch.onnx.export(model, (k, x), ‘test.onnx’) + + This would work as intended. However, the export function + would now assume that the ‘x’ input is intended to represent the optional + dictionary consisting of named arguments. In order to prevent this from being + an issue a constraint is placed to provide an empty dictionary as the last + input in the tuple args in such cases. The new call would look like this. + + torch.onnx.export(model, (k, x, {}), ‘test.onnx’) + f: a file-like object (has to implement fileno that returns a file descriptor) or a string containing a file name. A binary Protobuf will be written to this file. diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 5e9430f995f8..4cc3f47a3541 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -2,6 +2,7 @@ import torch import warnings from sys import maxsize as maxsize +from typing import Set import torch.onnx # This import monkey-patches graph manipulation methods on Graph, used for the @@ -125,7 +126,7 @@ def decorator(fn): def wrapper(g, *args, **kwargs): # some args may be optional, so the length may be smaller assert len(arg_descriptors) >= len(args) - args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] + args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] # type: ignore # only support _outputs in kwargs assert len(kwargs) <= 1 if len(kwargs) == 1: @@ -177,6 +178,29 @@ def _is_tensor(x): def _is_tensor_list(x): return isinstance(x.type(), torch._C.ListType) and isinstance(x.type().getElementType(), torch._C.TensorType) +def _get_tensor_rank(x): + if not _is_tensor(x) or x.type() is None: + return None + return x.type().dim() + +def _get_tensor_sizes(x, allow_nonstatic=True): + if not _is_tensor(x) or x.type() is None: + return None + if allow_nonstatic: + # Each individual symbol is returned as None. + # e.g. [1, 'a', 'b'] -> [1, None, None] + return x.type().varyingSizes() + # returns None, if exists any symbol in sizes. + # e.g. [1, 'a', 'b'] -> None + return x.type().sizes() + +def _get_tensor_dim_size(x, dim): + try: + sizes = _get_tensor_sizes(x) + return sizes[dim] + except Exception: + pass + return None def _unimplemented(op, msg): warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported") @@ -215,7 +239,7 @@ def _try_get_scalar_type(*args): def _select_helper(g, self, dim, index, apply_reshape=True): index_const = _maybe_get_scalar(index) - index_dim = index.type().dim() + index_dim = _get_tensor_rank(index) if not _is_value(index_const): # Index is a constant scalar. Make it a size 1 constant tensor. index = g.op("Constant", value_t=torch.LongTensor([index_const])) @@ -232,18 +256,18 @@ def _select_helper(g, self, dim, index, apply_reshape=True): def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False): if _export_onnx_opset_version <= 9: - from torch.onnx.symbolic_opset9 import _slice - return _slice(g, input, axes, starts, ends) + from torch.onnx.symbolic_opset9 import _slice as _slice9 + return _slice9(g, input, axes, starts, ends) else: - from torch.onnx.symbolic_opset10 import _slice - return _slice(g, input, axes, starts, ends, steps, dynamic_slice) + from torch.onnx.symbolic_opset10 import _slice as _slice10 + return _slice10(g, input, axes, starts, ends, steps, dynamic_slice) def _hardtanh_helper(g, input, min_val, max_val): if _export_onnx_opset_version <= 10: from torch.onnx.symbolic_opset9 import hardtanh return hardtanh(g, input, min_val, max_val) else: - from torch.onnx.symbolic_opset11 import hardtanh + from torch.onnx.symbolic_opset11 import hardtanh # type: ignore[no-redef] return hardtanh(g, input, min_val, max_val) def _is_fp(value): @@ -343,7 +367,8 @@ def _get_interpolate_attributes(g, mode, args): def _interpolate_get_scales(g, scale_factor, dim): offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) - if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor.isCompleteTensor() and scale_factor.type().dim() > 0): + scale_factor_rank = _get_tensor_rank(scale_factor) + if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor_rank is not None and scale_factor_rank > 0): return g.op("Concat", offsets, scale_factor, axis_i=0) else: scale_factor = _unsqueeze_helper(g, scale_factor, 0) @@ -380,7 +405,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_ size = g.op("Concat", *size, axis_i=0) scale_factor = _interpolate_size_to_scales(g, input, size, dim) else: - return _unimplemented("Both size and scales are None in __interpolate") + return _unimplemented("interpolate", "Both size and scales are None in __interpolate") return scale_factor, mode @@ -388,7 +413,7 @@ def _unbind_helper(g, self, dim, _outputs): if _export_onnx_opset_version <= 9: from torch.onnx.symbolic_opset9 import unbind else: - from torch.onnx.symbolic_opset11 import unbind + from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef] return unbind(g, self, dim, _outputs) @@ -396,7 +421,8 @@ def _scatter_helper(g, self, dim, index, src): if _export_onnx_opset_version <= 10: from torch.onnx.symbolic_opset9 import scatter else: - from torch.onnx.symbolic_opset11 import scatter + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore return scatter(g, self, dim, index, src) @@ -444,7 +470,8 @@ def _index_fill_reshape_helper(g, self, dim, index): if _export_onnx_opset_version <= 10: from torch.onnx.symbolic_opset9 import scatter else: - from torch.onnx.symbolic_opset11 import scatter + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore if self.type().dim() is None: return _unimplemented("index_fill", "input rank not accesible") @@ -632,4 +659,4 @@ def _cast_func_template(to_i, g, input, non_blocking): # Global set to store the list of quantized operators in the network. # This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX. -_quantized_ops = set() +_quantized_ops: Set[int] = set() diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 718b30f8fde3..6558df6e3d4c 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -209,12 +209,13 @@ def embedding_bag(g, import warnings warnings.warn("Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. " "Please use opset 11 or higher to export model for dynamic input shape.'") - if offsets.type().sizes() is not None: + offsets_dim_0 = sym_help._get_tensor_dim_size(offsets, 0) + if offsets_dim_0 is not None: if include_last_offset: - offset_len = offsets.type().sizes()[0] - 1 + offset_len = offsets_dim_0 - 1 offsets_extended = offsets else: - offset_len = offsets.type().sizes()[0] + offset_len = offsets_dim_0 offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))] offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) list_ = [] diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index de2acf6085a0..6e9fe3f27060 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -97,21 +97,21 @@ def index_put(g, self, indices_list_value, values, accumulate=False): # %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # %15 : None = prim::Constant() - # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # %22 : int[] = prim::Constant[value=[-1]]() # %23 : Tensor = aten::view(%16, %22) # %24 : Tensor?[] = prim::ListConstruct(%23) - # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = # aten::index_put(%mask, %24, %18, %30) # return (%25) # # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu), # %some_const : Float(requires_grad=0, device=cpu)): # %3 : Tensor = onnx::Equal(%0, %some_const) - # %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3) + # %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3) # %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4) # %19 : Tensor = onnx::Cast[to=9](%12) # %20 : Tensor = onnx::Constant[value={1}]() @@ -137,7 +137,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False): # %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # %22 : None = prim::Constant() # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) - # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) + # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # %30 : int[] = prim::Constant[value=[-1]]() # %31 : Tensor = aten::view(%23, %30) @@ -148,7 +148,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False): # # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu), # %some_const : Float(requires_grad=0, device=cpu)): - # %3 : Float(8, strides=[1], requires_grad=0, device=cpu) + # %3 : Float(8, strides=[1], requires_grad=0, device=cpu) # = onnx::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() # %4 : Tensor = onnx::Equal(%0, %some_const) # %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4) @@ -168,17 +168,17 @@ def index_put(g, self, indices_list_value, values, accumulate=False): # %32 : Tensor = onnx::Constant[value={0}]() # %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32) # %34 : Tensor = onnx::Slice(%24, %30, %31, %33) - # %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) # = onnx::ScatterND(%0, %22, %34) # return (%35) bool_inp = list(index.node().inputs())[0] if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool': - if values.type() is not None: - if values.type().dim() == 0: - from torch.onnx.symbolic_opset9 import masked_fill - return masked_fill(g, self, bool_inp, values) - return masked_scatter(g, self, bool_inp, values) + rank = sym_help._get_tensor_rank(values) + if rank is not None and rank == 0: + from torch.onnx.symbolic_opset9 import masked_fill + return masked_fill(g, self, bool_inp, values) + return masked_scatter(g, self, bool_inp, values) broadcast_index_shape = g.op("Shape", index) index = g.op("Unsqueeze", index, axes_i=[-1]) sub_data_shape = sym_help._slice_helper( @@ -201,8 +201,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False): @parse_args('v', 'i') def pixel_shuffle(g, self, upscale_factor): - dims = self.type().sizes() - if len(dims) != 4: + rank = sym_help._get_tensor_rank(self) + if rank is not None and rank != 4: return _unimplemented("pixel_shuffle", "only support 4d input") return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") @@ -280,11 +280,12 @@ def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_s "while exporting interpolate. Assuming that it is not a scalar.") if is_scalar: - if not input.type().dim(): + rank = sym_help._get_tensor_rank(input) + if rank is None: return sym_help._unimplemented("interpolate (with a scalar output_size)", "missing input shape (try giving an array of output_size values)") size = unsqueeze(g, size, 0) - size = [size for i in range(input.type().dim() - 2)] + size = [size for i in range(rank - 2)] size = g.op("Concat", *size, axis_i=0) size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long']) size = g.op("Concat", input_size, size, axis_i=0) @@ -299,9 +300,10 @@ def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_s mode_s=mode, # nearest, linear, or cubic nearest_mode_s="floor") else: # if not sym_help._is_none(scales) - if not input.type().dim(): + rank = sym_help._get_tensor_rank(input) + if rank is None: return sym_help._unimplemented("interpolate (with scales)", "missing input shape") - scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim()) + scales = sym_help._interpolate_get_scales(g, scale_factor, rank) return g.op("Resize", input, roi, @@ -549,19 +551,19 @@ def constant_pad_nd(g, input, padding, value=None): mode = "constant" value = sym_help._maybe_get_scalar(value) value = sym_help._if_scalar_type_as(g, value, input) - pad = _prepare_onnx_paddings(g, input.type().dim(), padding) + pad = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pad, value, mode_s=mode) def reflection_pad(g, input, padding): mode = "reflect" - paddings = _prepare_onnx_paddings(g, input.type().dim(), padding) + paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, paddings, mode_s=mode) def replication_pad(g, input, padding): mode = "edge" - paddings = _prepare_onnx_paddings(g, input.type().dim(), padding) + paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, paddings, mode_s=mode) @@ -639,9 +641,12 @@ def squeeze(g, self, dim=None): dim = sym_help._get_const(dim, 'i', 'dim') - input_shape = self.type().sizes() - from torch.onnx.symbolic_helper import _onnx_shape_inference - if input_shape is None or not _onnx_shape_inference: + input_rank = sym_help._get_tensor_rank(self) + adjusted_dim = dim + if input_rank is not None and dim < 0: + adjusted_dim += input_rank + dim_size = sym_help._get_tensor_dim_size(self, adjusted_dim) + if (dim < 0 and input_rank is None) or dim_size is None: # If onnx shape inference is not on, export always as dynamic. # Because we cannot tell if observed static shape is also static at runtime. # create 'cond' node (condition is shape[i]==1) @@ -661,11 +666,10 @@ def squeeze(g, self, dim=None): return if_node_outputs # For static input shape - if dim < 0: - dim += self.type().dim() - if input_shape[dim] > 1: + dim = adjusted_dim + if dim_size > 1: warnings.warn("This model contains a squeeze operation on dimension " + str(dim) + ". The size of " + - "this dimension in the given input is " + str(input_shape[dim]) + ". The model will " + + "this dimension in the given input is " + str(dim_size) + ". The model will " + "be exported without the squeeze node. If the model is intended to be used with dynamic " + "input shapes, please export with dynamic_axes argument.") return self @@ -861,7 +865,7 @@ def narrow(g, input, dim, start, length): @parse_args('v', 'i', 'i') def flatten(g, input, start_dim, end_dim): - dim = input.type().dim() + dim = sym_help._get_tensor_rank(input) # use ONNX's Flatten operator for cases where the output shape is 2D if start_dim == 1: if (end_dim == -1 or (dim is not None and end_dim == dim - 1)): diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index c0c1d48ebec0..1fa9fa5e985b 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -4,7 +4,7 @@ import torch.onnx.symbolic_opset9 as sym_opset9 from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type -from torch.onnx.symbolic_opset9 import _cast_Float +from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore import warnings @@ -148,10 +148,9 @@ def matmul(g, self, other): def prelu(g, self, weight): - if self.isCompleteTensor(): - self_sizes = self.type().sizes() - if self_sizes and len(self_sizes) > 2: - weight = g.op("Unsqueeze", weight, axes_i=list(range(1, len(self_sizes) - 1))) + self_rank = sym_help._get_tensor_rank(self) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) if _try_get_scalar_type(self): old_type, self, weight = _try_cast_integer_to_float(g, self, weight) return _cast_to_type(g, g.op("PRelu", self, weight), old_type) @@ -267,7 +266,7 @@ def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, mem def repeat(g, self, repeats): if not sym_help._is_value(repeats): repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) - if sym_help._is_packed_list(repeats): + if sym_help._is_packed_list(repeats): repeat_size_len = len(sym_help._unpack_list(repeats)) else: const_repeats = sym_help._maybe_get_const(repeats, 'is') diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index e395ce5c703f..bda62b638d22 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -13,6 +13,8 @@ import torch.onnx.symbolic_helper as sym_help from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented +from typing import Optional + import numpy import math import warnings @@ -210,9 +212,9 @@ def matmul(g, self, other): @parse_args('v', 'v', 'v', 't', 't') def addmm(g, self, mat1, mat2, beta, alpha): dtype = None - self_dtype = self.type().scalarType() - mat1_dtype = mat1.type().scalarType() - mat2_dtype = mat2.type().scalarType() + self_dtype = sym_help._try_get_scalar_type(self) + mat1_dtype = sym_help._try_get_scalar_type(mat1) + mat2_dtype = sym_help._try_get_scalar_type(mat2) if self_dtype is not None: dtype = self_dtype elif mat1_dtype is not None: @@ -220,8 +222,8 @@ def addmm(g, self, mat1, mat2, beta, alpha): elif mat2_dtype is not None: dtype = mat2_dtype - mat1_rank = mat1.type().dim() - mat2_rank = mat2.type().dim() + mat1_rank = sym_help._get_tensor_rank(mat1) + mat2_rank = sym_help._get_tensor_rank(mat2) def isNotNoneAnd(v, u): return v is not None and v != u @@ -311,7 +313,7 @@ def _maybe_cast_reduce_op_input(g, self): if dtype is not None: # pytorch reduce-ops cast all other integral types to int64 if not sym_help._is_fp(self) and not (dtype == 'Long'): - self = _cast_Long(g, self, False) + self = _cast_Long(g, self, False) # type: ignore return self @@ -461,8 +463,8 @@ def size(g, self, dim=None): if dim is None: return g.op("Shape", self) if sym_help._maybe_get_const(dim, 'i') < 0: - rank = self.type().dim() - if rank: + rank = sym_help._get_tensor_rank(self) + if rank is not None: dim = sym_help._maybe_get_const(dim, 'i') + rank dim = g.op("Constant", value_t=torch.tensor(dim)) return sym_help._size_helper(g, self, dim) @@ -474,8 +476,9 @@ def transpose(g, self, dim0, dim1): return self # NB: Transpose in ONNX is actually a Permute - if self.isCompleteTensor(): - axes = list(range(self.type().dim())) + rank = sym_help._get_tensor_rank(self) + if rank is not None: + axes = list(range(rank)) axes[dim0], axes[dim1] = axes[dim1], axes[dim0] return g.op("Transpose", self, perm_i=axes) else: @@ -510,7 +513,9 @@ def view_as(g, self, other): def prim_ConstantSplit(g, self, split_size, dim): - size = self.type().sizes()[dim] + size = sym_help._get_tensor_dim_size(self, dim) + if size is None: + return _unimplemented('prim::ConstantSplit', 'unknown dimension size') splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: @@ -523,7 +528,10 @@ def prim_ConstantSplit(g, self, split_size, dim): # TODO: Once we have proper scoping, stop reimplementing chunk, delete this # method, and use the desugared version def prim_ConstantChunk(g, self, chunks, dim): - split_size = (self.type().sizes()[dim] + chunks - 1) // chunks + dim_size = sym_help._get_tensor_dim_size(self, dim) + if dim_size is None: + return _unimplemented('prim::ConstantChunk', 'unknown dimension size') + split_size = (dim_size + chunks - 1) // chunks return prim_ConstantSplit(g, self, split_size, dim) @@ -531,8 +539,10 @@ def prim_ConstantChunk(g, self, chunks, dim): def unsafe_chunk(g, self, chunks, dim, _outputs=None): if _outputs is None: return sym_help._onnx_opset_unsupported_detailed('unsafe_chunk', 9, 11, 'Dynamic number of outputs not supported') - split_size = (self.type().sizes()[dim] + chunks - 1) // chunks - size = self.type().sizes()[dim] + size = sym_help._get_tensor_dim_size(self, dim) + if size is None: + return _unimplemented('unsafe_chunk', 'unknown dimension size') + split_size = (size + chunks - 1) // chunks splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: @@ -550,7 +560,9 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None): split_size = sym_help._get_const(split_size_or_sizes, 'i', 'split_size') dim = sym_help._get_const(dim, 'i', 'dim') - size = self.type().sizes()[dim] + size = sym_help._get_tensor_dim_size(self, dim) + if size is None: + return sym_help._onnx_opset_unsupported_detailed('split', 9, 11, 'Unknown dimension size not supported') splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: @@ -605,8 +617,8 @@ def squeeze(g, self, dim=None): squeeze_dim = sym_help._get_const(dim, 'i', 'dim') # Handle negative dims if squeeze_dim < 0: - rank = self.type().dim() - if rank: + rank = sym_help._get_tensor_rank(self) + if rank is not None: warnings.warn("ONNX export squeeze with negative axis " + str(squeeze_dim) + " might cause the onnx model to be incorrect. " + "Negative axis is not supported in ONNX. " + @@ -617,17 +629,17 @@ def squeeze(g, self, dim=None): else: return _unimplemented('squeeze', 'negative axis with unknown input rank') - input_shape = self.type().sizes() - if input_shape is None: + dim_size = sym_help._get_tensor_dim_size(self, squeeze_dim) + if dim_size is None: warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + " on an input " + "with unknown shape. Note that if the size of dimension " + str(squeeze_dim) + " of the input " + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + "non-singleton dimensions, it is recommended to export this model using opset " + "version 11 or higher.") return g.op("Squeeze", self, axes_i=[squeeze_dim]) - if input_shape[squeeze_dim] > 1: + if dim_size > 1: warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". The size of " + - "this dimension in the given input is " + str(input_shape[squeeze_dim]) + ". The model will " + + "this dimension in the given input is " + str(dim_size) + ". The model will " + "be exported without the squeeze node. If the model is intended to be used with dynamic " + "input shapes, please use opset version 11 to " + "export the model.") @@ -638,10 +650,9 @@ def squeeze(g, self, dim=None): return g.op("Squeeze", self, axes_i=[squeeze_dim]) def prelu(g, self, weight): - if self.isCompleteTensor(): - self_sizes = self.type().sizes() - if self_sizes and len(self_sizes) > 2: - weight = g.op("Unsqueeze", weight, axes_i=list(range(1, len(self_sizes) - 1))) + self_rank = sym_help._get_tensor_rank(self) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) return g.op("PRelu", self, weight) @@ -681,7 +692,9 @@ def leaky_relu(g, input, negative_slope, inplace=False): @parse_args('v', 'i') def glu(g, input, dim): - assert input.type().sizes()[dim] % 2 == 0 + dim_size = sym_help._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 first, second = g.op('Split', input, axis_i=dim, outputs=2) return g.op('Mul', first, g.op('Sigmoid', second)) @@ -709,7 +722,7 @@ def softmax(g, input, dim, dtype=None): # otherwise transpose the input to put the vectors to be normalized to the last dimension. # When input rank is not known at export time we compute softmax using a subgraph # with other operators - input_dim = input.type().dim() + input_dim = sym_help._get_tensor_rank(input) if input_dim is not None: # TODO: remove this as onnx opset 11 spec allows negative axes if dim < 0: @@ -751,7 +764,10 @@ def softplus(g, self, beta, threshold): def get_pool_ceil_padding(input, kernel_size, stride, padding): - dim = input.type().sizes()[-len(padding):] + sizes = sym_help._get_tensor_sizes(input) + dim = sizes[-len(padding):] if sizes is not None else None + if dim is None or any([i is None for i in dim]): + return _unimplemented(name, "input size not accessible") ceiled_output_dim = [int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + 1 for i in range(0, len(padding))] # ensure last pooling starts inside @@ -776,8 +792,6 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding): def _max_pool(name, tuple_fn, ndims, return_indices): @parse_args('v', 'is', 'is', 'is', 'is', 'i') def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): - if ceil_mode and not input.isCompleteTensor(): - return _unimplemented(name, "input size not accessible") if set(tuple_fn(dilation)) != {1}: return _unimplemented(name, "dilation") if not stride: @@ -834,8 +848,6 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): def _avg_pool(name, tuple_fn): @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none') def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None): - if ceil_mode and not input.isCompleteTensor(): - return _unimplemented(name, "input size not accessible") if not stride: stride = kernel_size padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name) @@ -881,11 +893,15 @@ def symbolic_fn(g, input, output_size): return sym_help._onnx_unsupported('adaptive pooling, since output_size is not constant.') if output_size == [1] * len(output_size) and type == "AveragePool": return g.op("GlobalAveragePool", input) - if not input.isCompleteTensor(): + sizes = sym_help._get_tensor_sizes(input) + try: + dim = sizes[2:] + except Exception: + dim = None + if dim is None or any([i is None for i in dim]): if output_size == [1] * len(output_size): return g.op("GlobalMaxPool", input), None return _unimplemented(name, 'input size not accessible') - dim = input.type().sizes()[2:] # verify if output size % input size = 0 for all dim mod = [dim[i] % output_size[i] for i in range(0, len(dim))] if mod != [0] * len(mod): @@ -949,21 +965,21 @@ def constant_pad_nd(g, input, padding, value): return sym_help._onnx_opset_unsupported_detailed('Pad', 9, 11, 'The value for the padding must be constant') padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(input.type().dim(), padding) + paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value) def reflection_pad(g, input, padding): mode = "reflect" padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(input.type().dim(), padding) + paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pads_i=paddings, mode_s=mode) def replication_pad(g, input, padding): mode = "edge" padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(input.type().dim(), padding) + paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pads_i=paddings, mode_s=mode) @@ -1133,7 +1149,7 @@ def log_softmax(g, input, dim, dtype=None): # PyTorch dim and ONNX axis have different meanings. # See Softmax comment for details. # TODO: remove this as onnx opset 11 spec allows negative axes - input_dim = input.type().dim() + input_dim = sym_help._get_tensor_rank(input) if input_dim is None: return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input. " @@ -1159,11 +1175,19 @@ def log_softmax(g, input, dim, dtype=None): @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i', 'i') def _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32): - weight_size = weight.type().sizes() + weight_size = sym_help._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + kernel_shape = None + + if kernel_shape is None or any([i is None for i in kernel_shape]): + raise RuntimeError('Unsupported: ONNX export of convolution for kernel ' + 'of unknown shape.') args = [input, weight] # ONNX only supports 1D bias - if not sym_help._is_none(bias) and bias.type().dim() == 1: + if not sym_help._is_none(bias) and sym_help._get_tensor_rank(bias) == 1: args.append(bias) kwargs = {"kernel_shape_i": weight_size[2:], @@ -1184,7 +1208,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation, n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) - if not sym_help._is_none(bias) and bias.type().dim() != 1: + if not sym_help._is_none(bias) and sym_help._get_tensor_rank(bias) != 1: return g.op("Add", n, bias) else: return n @@ -1223,26 +1247,31 @@ def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, gr @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled): sym_help.assert_training_mode(training, "batch_norm") - input_sizes = input.type().sizes() + batch_size = sym_help._get_tensor_dim_size(input, 0) + channel_size = sym_help._get_tensor_dim_size(input, 1) if weight is None or sym_help._is_none(weight): - assert len(input_sizes) > 1 - weight_value = torch.tensor([1.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of batch_norm for unknown ' + 'channel size.') + weight_value = torch.tensor([1.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') weight = g.op("Constant", value_t=weight_value) if bias is None or sym_help._is_none(bias): - assert len(input_sizes) > 1 - bias_value = torch.tensor([0.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of batch_norm for unknown ' + 'channel size.') + bias_value = torch.tensor([0.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') bias = g.op("Constant", value_t=bias_value) - # If track_running_stats is set to False batch statistics are instead used during evaluation time + # If track_running_stats is set to False batch statistics are instead used during evaluation time if running_mean is None or sym_help._is_none(running_mean) or running_var is None or sym_help._is_none(running_var): - assert len(input_sizes) > 1 - reshape_in = g.op("Reshape", input, - g.op("Constant", value_t=torch.tensor([input_sizes[0], input_sizes[1], -1], dtype=torch.int64))) + assert batch_size is not None and channel_size is not None + reshape_in = g.op("Reshape", input, + g.op("Constant", value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64))) trans_in = g.op('Transpose', reshape_in, perm_i=[0, 2, 1]) - running_var, running_mean = _var_mean(g, trans_in, - g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), + running_var, running_mean = _var_mean(g, trans_in, + g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), False, False) out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var, epsilon_f=eps, @@ -1288,15 +1317,19 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled): - input_sizes = input.type().sizes() + channel_size = sym_help._get_tensor_dim_size(input, 1) if weight is None or sym_help._is_none(weight): - assert len(input_sizes) > 1 - weight_value = torch.tensor([1.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of instance_norm for unknown ' + 'channel size.') + weight_value = torch.tensor([1.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') weight = g.op("Constant", value_t=weight_value) if bias is None or sym_help._is_none(bias): - assert len(input_sizes) > 1 - bias_value = torch.tensor([0.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of instance_norm for unknown ' + 'channel size.') + bias_value = torch.tensor([0.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') bias = g.op("Constant", value_t=bias_value) return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) @@ -1306,13 +1339,17 @@ def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_s def unfold(g, input, dimension, size, step): if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step) - if input.isCompleteTensor(): - sizedim = input.type().sizes()[dimension] + sizes = sym_help._get_tensor_sizes(input) + try: + sizedim = sizes[dimension] + except Exception: + sizedim = None + if sizedim is not None: low_indices = range(0, sizedim, step) hi_indices = range(size, sizedim + 1, step) stack = [sym_help._slice_helper(g, input, axes=[dimension], starts=[low], ends=[hi]) for low, hi in zip(low_indices, hi_indices)] - ndim = input.type().dim() + ndim = len(sizes) perm = list(range(0, ndim)) perm.append(perm.pop(dimension)) unsqueeze = [g.op("Unsqueeze", g.op("Transpose", t, perm_i=perm), axes_i=[dimension]) for t in stack] @@ -1373,11 +1410,12 @@ def index_copy(g, self, dim, index, source): def type_as(g, self, other): - if self.isCompleteTensor() and other.isCompleteTensor() and self.type().scalarType() == other.type().scalarType(): + self_dtype = sym_help._try_get_scalar_type(self) + other_dtype = sym_help._try_get_scalar_type(other) + if self_dtype == other_dtype and self_dtype is not None: return self - if other.isCompleteTensor(): - other_type_name = other.type().scalarType() - return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other_type_name]) + if other_dtype is not None: + return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other_dtype]) else: if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: # We don't know the type of other, bail by emitting ATen @@ -1573,8 +1611,9 @@ def empty_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False): - if dtype is None and self.isCompleteTensor(): - dtype = self.type().scalarType() + self_dtype = sym_help._try_get_scalar_type(self) + if dtype is None and self_dtype is not None: + dtype = self_dtype dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) return empty(g, sizes, dtype, layout, device, pin_memory) @@ -1626,8 +1665,9 @@ def zeros_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False): - if dtype is None and self.isCompleteTensor(): - dtype = self.type().scalarType() + self_dtype = sym_help._try_get_scalar_type(self) + if dtype is None and self_dtype is not None: + dtype = self_dtype dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) return zeros(g, sizes, dtype, layout, device, pin_memory) @@ -1677,8 +1717,9 @@ def full_like(g, input, fill_value, dtype=None, layout=None, device=None, pin_me def new_full(g, self, size, fill_value, dtype, layout, device, pin_memory=False): - if dtype is None and self.isCompleteTensor(): - dtype = self.type().scalarType() + self_dtype = sym_help._try_get_scalar_type(self) + if dtype is None and self_dtype is not None: + dtype = self_dtype dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) return full(g, size, fill_value, dtype, layout, device, pin_memory) @@ -1743,8 +1784,8 @@ def hardtanh(g, self, min_val, max_val): @parse_args('v') def hardswish(g, self): input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float))) - hardtanh_ = sym_help._hardtanh_helper(g, input, - g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), + hardtanh_ = sym_help._hardtanh_helper(g, input, + g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) return g.op("Mul", self, hardtanh_) @@ -1757,8 +1798,8 @@ def alias(g, self): def unsqueeze(g, self, dim): # Handle negative dim if dim < 0: - rank = self.type().dim() - if rank: + rank = sym_help._get_tensor_rank(self) + if rank is not None: warnings.warn("ONNX export unsqueeze with negative axis " + str(dim) + " might cause the onnx model to be incorrect. " + "Negative axis is not supported in ONNX. " + @@ -1776,10 +1817,16 @@ def unsqueeze(g, self, dim): def sort(g, self, dim, decending, out=None): if out is not None: _unimplemented("Sort", "Out parameter is not supported for sort") - if not self.isCompleteTensor(): + self_sizes = sym_help._get_tensor_sizes(self) + try: + dim_size = self_sizes[dim] + except Exception: + dim_size = None + + if dim_size is None: return _unimplemented("Sort", "input size not accessible") - return g.op("TopK", self, k_i=self.type().sizes()[dim], axis_i=dim, outputs=2) + return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) def numel(g, self): @@ -1842,9 +1889,11 @@ def repeat(g, self, repeats): @parse_args('v', 'i') def pixel_shuffle(g, self, upscale_factor): - dims = self.type().sizes() + dims = sym_help._get_tensor_sizes(self) if len(dims) != 4: return _unimplemented("pixel_shuffle", "only support 4d input") + if any([i is None for i in dims[1:]]): + return _unimplemented("pixel_shuffle", "only support static input shape, except for batch size") output_channel = dims[1] // upscale_factor // upscale_factor after_view = view(g, self, g.op("Constant", value_t=torch.tensor([-1, output_channel, upscale_factor, upscale_factor, dims[2], dims[3]]))) @@ -1880,7 +1929,9 @@ def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, variant = 'RNN' w_hh = all_weights[1] - hidden_size = w_hh.type().sizes()[1] + hidden_size = sym_help._get_tensor_dim_size(w_hh, 1) + if hidden_size is None: + return _unimplemented("RNN/GRU/LSTM", "unknown hidden size") unidirectional = not bidirectional @@ -2092,7 +2143,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first): # It's really only necessary because those operators expand to something that # only works with int32 types in Caffe2... if lengths.type().scalarType() != 'Int': - lengths = _cast_Int(g, lengths, False) + lengths = _cast_Int(g, lengths, False) # type: ignore return g.op("prim::PackPadded", input, lengths, outputs=2) @@ -2164,7 +2215,7 @@ def erf(g, input): @parse_args('v', 'i', 'i') def flatten(g, input, start_dim, end_dim): - dim = input.type().dim() + dim = sym_help._get_tensor_rank(input) if dim is None: return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input. " @@ -2239,13 +2290,16 @@ def scatter(g, self, dim, index, src): @parse_args('v', 'i', 'v', 'v') def scatter_add(g, self, dim, index, src): - if not self.isCompleteTensor(): - return _unimplemented("scatter_add", "input size not accessible") - dtype = self.type().scalarType() + dtype = sym_help._try_get_scalar_type(self) + if dtype is None: + return _unimplemented("scatter_add", "input dtype not accessible") dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) dtype = sym_help.scalar_type_to_pytorch_type[dtype] - sizes = self.type().sizes() - to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=dtype)) + sizes = sym_help._get_tensor_sizes(self, allow_nonstatic=False) + if sizes: + to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=dtype)) + else: + to_add = zeros_like(self, dtype) to_add = sym_help._scatter_helper(g, to_add, dim, index, src) return add(g, self, to_add) @@ -2436,7 +2490,7 @@ def _get_arange_dtype(dtype): def masked_fill(g, self, mask, value): - mask = _cast_Bool(g, mask, False) + mask = _cast_Bool(g, mask, False) # type: ignore value = sym_help._maybe_get_scalar(value) return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self) @@ -2489,7 +2543,7 @@ def try_mask_to_index(index): elif len(adv_idx_indices) == 1: return index_select(g, self, adv_idx_indices[0], indices[adv_idx_indices[0]]) else: - rank = self.type().dim() + rank = sym_help._get_tensor_rank(self) if rank is None: raise NotImplementedError("Unsupported aten::index operator of advanced indexing on tensor of unknown rank, " + "try turning on shape and type propagate during export: " + @@ -2501,7 +2555,6 @@ def try_mask_to_index(index): " is achieved by combination of multiple ONNX operators, " + "including Reshape, Transpose, Concat, and Gather. " + "If indices include negative values, the exported graph will produce incorrect results.") - rank = self.type().dim() adv_idx_count = len(adv_idx_indices) shape_tensor = _shape_as_tensor(g, self) dim_tensor_list = [ @@ -2620,8 +2673,12 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): return g.op("ATen", input, weight, bias, num_groups_i=num_groups, eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm") - input_sizes = input.type().sizes() - assert input_sizes[1] % num_groups == 0 + channel_size = sym_help._get_tensor_dim_size(input, 1) + if channel_size is not None: + assert channel_size % num_groups == 0 + input_rank = sym_help._get_tensor_rank(input) + if input_rank is None: + return _unimplemented("group_norm", "unknown input rank") # 0 in the shape list keeps dimension value unchanged. shape = [0, num_groups, -1] input_reshaped = g.op('Reshape', input, g.op('Constant', value_t=torch.LongTensor(shape))) @@ -2647,14 +2704,14 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): bias = g.op("Constant", value_t=bias_value) # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] - axes = list(range(1, len(input_sizes) - 1)) + axes = list(range(1, input_rank - 1)) return add(g, mul(g, norm, g.op("Unsqueeze", weight, axes_i=axes)), g.op("Unsqueeze", bias, axes_i=axes)) @parse_args('v', 'v', 'i') def _weight_norm(g, weight_v, weight_g, dim): - rank = weight_v.type().dim() - if rank: + rank = sym_help._get_tensor_rank(weight_v) + if rank is not None: # W = g * ((v) / ||v||) # Compute norm_except_dim for l2 norm. dim = None means over all dims # torch's weight_norm module sets dim = -1 if it's None. @@ -2734,6 +2791,7 @@ def as_strided(g, self, sizes, strides, offset=None): sizes = sym_help._maybe_get_const(sizes, 'is') rank = len(strides) self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + ind: Optional[torch.Tensor] if not sym_help._is_value(sizes): ind = torch.tensor([0], dtype=torch.long) for i, (size, stride) in enumerate(zip(sizes, strides)): diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py index 48114d6c472b..c059e8f2eb31 100644 --- a/torch/onnx/symbolic_registry.py +++ b/torch/onnx/symbolic_registry.py @@ -1,6 +1,7 @@ import warnings import importlib from inspect import getmembers, isfunction +from typing import Dict, Tuple, Any, Union # The symbolic registry "_registry" is a dictionary that maps operators # (for a specific domain and opset version) to their symbolic functions. @@ -8,9 +9,9 @@ # The keys are tuples (domain, version), (where domain is a string, and version is an int), # and the operator's name (string). # The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic -_registry = {} +_registry: Dict[Tuple[str, int], Dict] = {} -_symbolic_versions = {} +_symbolic_versions: Dict[Union[int, str], Any] = {} from torch.onnx.symbolic_helper import _onnx_stable_opsets for opset_version in _onnx_stable_opsets: module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version)) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 5c41306b9ee2..479f874819f2 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -18,6 +18,7 @@ from torch.jit import _unique_state_dict from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto +from typing import Union, Tuple, List # the flag to tell the user whether it's in the middle of ONNX export or not @@ -76,7 +77,7 @@ def export(model, args, f, export_params=True, verbose=False, training=None, if aten or export_raw_ir: assert operator_export_type is None assert aten ^ export_raw_ir - operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW + operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW elif operator_export_type is None: if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE: operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK @@ -316,6 +317,35 @@ def _decide_external_data_format(use_external_data_format, operator_export_type, model_file_location = f if val_use_external_data_format and isinstance(f, str) else str() return val_use_external_data_format, model_file_location +def _decide_input_format(model, args): + import inspect + try: + sig = inspect.signature(model.forward) + ordered_list_keys = list(sig.parameters.keys()) + if isinstance(args[-1], dict): + args_dict = args[-1] + args = list(args)[:-1] + n_nonkeyword = len(args) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default is param.empty: + args.append(None) + else: + args.append(param.default) + args = tuple(args) + return args + # Cases of models without forward functions and dict inputs + except AttributeError: + warnings.warn("Model has no forward function") + return args + # Cases of models with no input args + except IndexError: + warnings.warn("No input args") + return args def _trace(func, args, operator_export_type, return_outs=False): # Special case for common case of passing a single Tensor @@ -351,6 +381,7 @@ def _trace_and_get_graph_from_model(model, args): def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes): torch_out = None + params: Union[List, Tuple] if isinstance(model, torch.jit.ScriptModule): try: graph = model.forward.graph @@ -442,7 +473,7 @@ def _model_to_graph(model, args, verbose=False, param_names = input_and_param_names[len(input_and_param_names) - len(params):] params_dict = dict(zip(param_names, params)) - if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training): + if training is None or training == TrainingMode.EVAL: params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions: @@ -476,7 +507,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t if aten or export_raw_ir: assert operator_export_type is None assert aten ^ export_raw_ir - operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW + operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW elif operator_export_type is None: operator_export_type = OperatorExportTypes.ONNX return _export_to_pretty_string(model, args, f, export_params, verbose, training, @@ -512,6 +543,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, opset_version) val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type) val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type, training) + args = _decide_input_format(model, args) graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, @@ -562,6 +594,7 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini # in ONNX, fall through will occur and export the operator as is, as a custom ONNX op. operator_export_type = OperatorExportTypes.ONNX_FALLTHROUGH with select_model_mode_for_export(model, training): + args = _decide_input_format(model, args) graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type) # The output 'unsupported_ops' will contain the names of all the ops that are not supported in ONNX @@ -627,6 +660,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None, val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format, operator_export_type, f) + args = _decide_input_format(model, args) if dynamic_axes is None: dynamic_axes = {} _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) @@ -1051,6 +1085,10 @@ def _graph_constant(g, value, dims, type, *args, **kwargs): dims = [1] isscalar = True type = type.lower() + tensor: Union[torch.CharTensor, torch.ShortTensor, + torch.IntTensor, torch.LongTensor, + torch.HalfTensor, torch.FloatTensor, + torch.DoubleTensor] if type == "char": tensor = torch.CharTensor(*dims) elif type == "short": @@ -1068,7 +1106,7 @@ def _graph_constant(g, value, dims, type, *args, **kwargs): else: raise ValueError("Unknown type, type should be one of the following strings: " "char, short, int, long, half, float, double") - tensor.fill_(value) + tensor.fill_(value) # type: ignore if isscalar: return g.op("Constant", *args, value_z=tensor, **kwargs) return g.op("Constant", *args, value_t=tensor, **kwargs) @@ -1141,8 +1179,8 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): dynamic_axes[key] = value_dict -torch._C.Graph.op = _graph_op -torch._C.Graph.at = _graph_at -torch._C.Block.op = _block_op -torch._C.Graph.constant = _graph_constant -torch._C.Node.__getitem__ = _node_getitem +torch._C.Graph.op = _graph_op # type: ignore +torch._C.Graph.at = _graph_at # type: ignore +torch._C.Block.op = _block_op # type: ignore +torch._C.Graph.constant = _graph_constant # type: ignore +torch._C.Node.__getitem__ = _node_getitem # type: ignore diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 7d413b959415..0a302008cd22 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -5,6 +5,7 @@ from copy import deepcopy from itertools import chain import warnings +import functools class _RequiredParameter(object): @@ -34,6 +35,8 @@ def __init__(self, params, defaults): torch._C._log_api_usage_once("python.optimizer") self.defaults = defaults + self._hook_for_profile() + if isinstance(params, torch.Tensor): raise TypeError("params argument given to the optimizer should be " "an iterable of Tensors or dicts, but got " + @@ -60,6 +63,7 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) + self._hook_for_profile() # To support multiprocessing pickle/unpickle. def __repr__(self): format_string = self.__class__.__name__ + ' (' @@ -72,6 +76,24 @@ def __repr__(self): format_string += ')' return format_string + def _hook_for_profile(self): + self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__) + + def profile_hook_step(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + obj, *_ = args + profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__) + with torch.autograd.profiler.record_function(profile_name): + return func(*args, **kwargs) + return wrapper + + hooked = getattr(self.__class__.step, "hooked", None) + if not hooked: + self.__class__.step = profile_hook_step(self.__class__.step) + self.__class__.step.hooked = True + def state_dict(self): r"""Returns the state of the optimizer as a :class:`dict`. @@ -179,17 +201,20 @@ def zero_grad(self, set_to_none: bool = False): (in one case it does the step with a gradient of 0 and in the other it skips the step altogether). """ - for group in self.param_groups: - for p in group['params']: - if p.grad is not None: - if set_to_none: - p.grad = None - else: - if p.grad.grad_fn is not None: - p.grad.detach_() + if not hasattr(self, "_zero_grad_profile_name"): + self._hook_for_profile() + with torch.autograd.profiler.record_function(self._zero_grad_profile_name): + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None else: - p.grad.requires_grad_(False) - p.grad.zero_() + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() def step(self, closure): r"""Performs a single optimization step (parameter update). diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index e1315e370269..909aa0c6cc62 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -32,6 +32,8 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + params = list(params) + sparse_params = [] for index, param in enumerate(params): if isinstance(param, dict): diff --git a/torch/overrides.py b/torch/overrides.py index e8a3933a1954..2af6e36ea914 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -505,7 +505,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.lt: lambda input, other, out=None: -1, torch.less: lambda input, other, out=None: -1, torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, - torch.lu_solve: lambda input, LU_data, LU_pivots, out=None: -1, + torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1, torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, torch.masked_fill: lambda input, mask, value: -1, torch.masked_scatter: lambda input, mask, source: -1, diff --git a/torch/package/exporter.py b/torch/package/exporter.py index 7395ec96ccd3..11404039d546 100644 --- a/torch/package/exporter.py +++ b/torch/package/exporter.py @@ -1,5 +1,5 @@ import torch -from torch.serialization import normalize_storage_type, location_tag, _should_read_directly +from torch.serialization import normalize_storage_type, location_tag import io import pickletools from .find_file_dependencies import find_files_source_depends_on @@ -7,7 +7,7 @@ from ._importlib import _normalize_path import types import importlib -from typing import List, Any, Callable, Dict, Tuple +from typing import List, Any, Callable, Dict, Tuple, Union, Iterable from distutils.sysconfig import get_python_lib from pathlib import Path import linecache @@ -192,7 +192,7 @@ def _get_source_of_module(self, module: types.ModuleType) -> str: if result is None: extra = '' if self.verbose: - extra = f' See the dependency graph for more info: {self._write_dep_graph(module.__name__)}' + extra = f' See the dependency graph for more info: \n{self._write_dep_graph(module.__name__)}' raise ValueError(f'cannot save source for module "{module.__name__}" because ' f'its source file "{filename}" could not be found.{extra}') return ''.join(result) @@ -211,7 +211,7 @@ def require_module(self, module_name: str, dependencies=True): of modules""" for pattern, action in self.patterns: - if pattern.fullmatch(module_name): + if pattern.matches(module_name): action(module_name) return @@ -220,7 +220,7 @@ def require_module(self, module_name: str, dependencies=True): if self.verbose: print(f'implicitly adding {root_name} to external modules ' f'since it is part of the standard library and is a dependency.') - self.extern_module(root_name) + self.save_extern_module(root_name) return self.save_module(module_name, dependencies) @@ -303,70 +303,64 @@ def save_binary(self, package, resource, binary: bytes): filename = self._filename(package, resource) self._write(filename, binary) - def extern_module(self, module_name: str): + def mock(self, include: 'GlobPattern', *, exclude: 'GlobPattern' = ()): + """Replace some required modules with a mock implementation. Mocked modules will return a fake + object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes + find files that are imported by model files but whose functionality is never used + (e.g. custom serialization code or training helpers). + Use this function to mock this functionality out without having to modify the original code. + + Args: + include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings + for the names of the modules to be mocked out. Strings can also be a glob-style pattern + string that may match multiple modules. Any required dependencies that match this pattern + string will be mocked out automatically. + + Examples: + 'torch.**' -- matches torch and all submodules of torch, e.g. 'torch.nn' and torch.nn.functional' + 'torch.*' -- matches 'torch.nn' or 'torch.functional', but not 'torch.nn.functional' + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + e.g. include='torch.**', exclude='torch.foo' will mock all torch packages except 'torch.foo' Default: [] + + """ + self.patterns.append((_GlobGroup(include, exclude), self.save_mock_module)) + + def extern(self, include: 'GlobPattern', *, exclude: 'GlobPattern' = ()): """Include `module` in the list of external modules the package can import. This will prevent dependency discover from saving it in the package. The importer will load an external module directly from the standard import system. Code for extern modules must also exist in the process loading the package. Args: - module_name (str): e.g. "my_package.my_subpackage" the name of the external module. - This can also be a glob-style pattern, as described in :meth:`mock_module` - """ - if self._add_if_pattern(module_name, self.extern_module): - return - - if module_name not in self.external: - self.external.append(module_name) + include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock` - def extern_modules(self, module_names: List[str]): - """Extern a list of modules. Convience wrapper for calling :meth:`extern_module` on many items. + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. - Args: - module_names (List[str]): List of module names """ - for m in module_names: - self.extern_module(m) + self.patterns.append((_GlobGroup(include, exclude), self.save_extern_module)) - def mock_module(self, module_name: str): - """Replace the code for `module_name` in the package with a fake implementation. This module will return a fake - object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes - find files that are imported by model files but whose functionality is never used - (e.g. custom serialization code or training helpers). - Use this function to mock this functionality out without having to modify the original code. + def save_extern_module(self, module_name: str): + """Add `module_name` to the list of external modules, regardless of whether it is + required by other modules. - Args: - module_name (str): e.g. "my_package.my_subpackage" the name of the module to be mocked out. - The module_name can also be a glob-style pattern string that may match multiple modules. - Any required dependencies that match this pattern string will be mocked out automatically. - Examples: - 'torch.**' -- matches all submodules of torch, e.g. 'torch.nn' and torch.nn.functional' - 'torch.*' -- matches 'torch.nn' or 'torch.functional', but not 'torch.nn.functional' + Prefer using `extern` to only mark modules extern if they are actually required by the packaged code. """ - if self._add_if_pattern(module_name, self.mock_module): - return + if module_name not in self.external: + self.external.append(module_name) + def save_mock_module(self, module_name: str): + """Add `module_name` to the package, implemented it with a mocked out version that + can be imported but does not include any implementations. + + Prefer using `mock` to only include this module if it is required by other modules. + """ if '_mock' not in self.provided: self.save_source_file('_mock', str(Path(__file__).parent / '_mock.py'), dependencies=False) is_package = hasattr(self._import_module(module_name), '__path__') self.save_source_string(module_name, _MOCK_IMPL, is_package, dependencies=False) - - def mock_modules(self, module_names): - """Mock a list of modules. Convience wrapper for calling :meth:`mock_module` on many items. - - Args: - module_names (List[str]): List of module names - """ - for module_name in module_names: - self.mock_module(module_name) - - def _add_if_pattern(self, potential_pattern: str, action: Callable[[str], None]): - if '*' in potential_pattern or '?' in potential_pattern: - self.patterns.append((_module_glob_to_re(potential_pattern), action)) - return True - return False - def _module_is_already_provided(self, qualified_name: str) -> bool: for mod in self.external: if qualified_name == mod or qualified_name.startswith(mod + '.'): @@ -411,22 +405,18 @@ def close(self): ... """ if self.verbose: - print(f"Dependency graph for exported package: {self._write_dep_graph()}") + print(f"Dependency graph for exported package: \n{self._write_dep_graph()}") # Write each tensor to a file named tensor/the_tensor_key in the zip archive for key in sorted(self.serialized_storages.keys()): name = 'data/{}'.format(key) storage = self.serialized_storages[key] - if storage.device.type == 'cpu': - # If it's on the CPU we can directly copy it into the zip file - num_bytes = storage.size() * storage.element_size() - self.zip_file.write_record(name, storage.data_ptr(), num_bytes) - else: - # Copy to a buffer, then serialize that - buf = io.BytesIO() - storage._write_file(buf, _should_read_directly(buf)) - buf_value = buf.getvalue() - self._write(name, buf_value) + # location information is saved in python, but to actually + # get the data from non cpu tensors we need to move them over first + if storage.device.type != 'cpu': + storage = storage.cpu() + num_bytes = storage.size() * storage.element_size() + self.zip_file.write_record(name, storage.data_ptr(), num_bytes) contents = ('\n'.join(self.external) + '\n') self._write('extern_modules', contents) del self.zip_file @@ -441,7 +431,6 @@ def _can_implicitly_extern(self, module_name: str): return module_name == 'torch' or (module_name not in _DISALLOWED_MODULES and _is_builtin_or_stdlib_module(self._import_module(module_name))) - # even though these are in the standard library, we do not allow them to be # automatically externed since they offer a lot of system level access _DISALLOWED_MODULES = ['sys', 'io'] @@ -471,8 +460,41 @@ def _read_file(filename: str) -> str: b = f.read() return b.decode('utf-8') -_glob_re_filter = {'**': '.*', '*': '[^.]*', '?': '.', '.': '\\.'} -_glob_split = re.compile(f'({"|".join(re.escape(x) for x in _glob_re_filter.keys())})') -def _module_glob_to_re(module_name): - pattern = ''.join(_glob_re_filter.get(x, x) for x in _glob_split.split(module_name)) - return re.compile(pattern) +GlobPattern = Union[str, Iterable[str]] + + +class _GlobGroup: + def __init__(self, include: 'GlobPattern', exclude: 'GlobPattern'): + self._dbg = f'_GlobGroup(include={include}, exclude={exclude})' + self.include = _GlobGroup._glob_list(include) + self.exclude = _GlobGroup._glob_list(exclude) + + def __str__(self): + return self._dbg + + def matches(self, candidate: str) -> bool: + candidate = '.' + candidate + return any(p.fullmatch(candidate) for p in self.include) and all(not p.fullmatch(candidate) for p in self.exclude) + + @staticmethod + def _glob_list(elems: 'GlobPattern'): + if isinstance(elems, str): + return [_GlobGroup._glob_to_re(elems)] + else: + return [_GlobGroup._glob_to_re(e) for e in elems] + + @staticmethod + def _glob_to_re(pattern: str): + # to avoid corner cases for the first component, we prefix the candidate string + # with '.' so `import torch` will regex against `.torch` + def component_to_re(component): + if '**' in component: + if component == '**': + return '(\\.[^.]+)*' + else: + raise ValueError('** can only appear as an entire path segment') + else: + return '\\.' + '[^.]*'.join(re.escape(x) for x in component.split('*')) + + result = ''.join(component_to_re(c) for c in pattern.split('.')) + return re.compile(result) diff --git a/torch/package/importer.py b/torch/package/importer.py index 455119d18a1e..ffd474733021 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -140,7 +140,7 @@ def load_pickle(self, package: str, resource: str, map_location=None) -> Any: def _read_extern(self): return self.zip_reader.get_record('extern_modules').decode('utf-8').splitlines(keepends=False) - def _make_module(self, name: str, filename: Optional[str], is_package: bool): + def _make_module(self, name: str, filename: Optional[str], is_package: bool, parent: str): spec = importlib.machinery.ModuleSpec(name, self, is_package=is_package) # type: ignore module = importlib.util.module_from_spec(spec) self.modules[name] = module @@ -150,12 +150,18 @@ def _make_module(self, name: str, filename: Optional[str], is_package: bool): ns['__file__'] = filename ns['__cached__'] = None ns['__builtins__'] = self.patched_builtins + + # pre-emptively install on the parent to prevent IMPORT_FROM from trying to + # access sys.modules + self._install_on_parent(parent, name, module) + if filename is not None: code = self._compile_source(filename) exec(code, ns) + return module - def _load_module(self, name: str): + def _load_module(self, name: str, parent: str): cur : _PathNode = self.root for atom in name.split('.'): if not isinstance(cur, _PackageNode) or atom not in cur.children: @@ -166,7 +172,7 @@ def _load_module(self, name: str): if isinstance(cur, _ExternNode): module = self.modules[name] = importlib.import_module(name) return module - return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode)) # type: ignore + return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore def _compile_source(self, fullpath): source = self.zip_reader.get_record(fullpath) @@ -179,6 +185,14 @@ def get_source(self, module_name) -> str: module = self.import_module(module_name) return self.zip_reader.get_record(module.__file__).decode('utf-8') + def _install_on_parent(self, parent: str, name: str, module: types.ModuleType): + if not parent: + return + # Set the module as an attribute on its parent. + parent_module = self.modules[parent] + if parent_module.__loader__ is self: # type: ignore + setattr(parent_module, name.rpartition('.')[2], module) + # note: copied from cpython's import code, with call to create module replaced with _make_module def _do_find_and_load(self, name): path = None @@ -196,13 +210,10 @@ def _do_find_and_load(self, name): msg = (_ERR_MSG + '; {!r} is not a package').format(name, parent) raise ModuleNotFoundError(msg, name=name) from None - module = self._load_module(name) + module = self._load_module(name, parent) + + self._install_on_parent(parent, name, module) - if parent: - # Set the module as an attribute on its parent. - parent_module = self.modules[parent] - if parent_module.__loader__ is self: # type: ignore - setattr(parent_module, name.rpartition('.')[2], module) return module # note: copied from cpython's import code diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index b5cf78b05f33..5aabbd66c4b1 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -35,7 +35,7 @@ def fuse(self, model: GraphModule, self.modules = dict(input_root.named_modules()) additional_fusion_patterns = \ - fuse_custom_config_dict.get("additional_quant_pattern", {}) + fuse_custom_config_dict.get("additional_fusion_pattern", {}) fusion_patterns = get_combined_dict( get_default_fusion_patterns(), additional_fusion_patterns) # find fusion diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index 146dad1eab2e..fe13d0a3fed7 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -56,6 +56,12 @@ def insert(fn): def input_output_observed(qh): return type(qh) not in DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER + +class MatchAllNode: + """ A node pattern that matches all nodes + """ + pass + # Example use of register pattern function: # @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) # class ConvBNReLUFusion(): @@ -79,6 +85,9 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize): self_match = pattern arg_matches = [] + if isinstance(self_match, type) and issubclass(self_match, MatchAllNode): + return True + if len(node.users) > max_uses: return False diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py new file mode 100644 index 000000000000..6326a2e0da59 --- /dev/null +++ b/torch/quantization/fx/qconfig_utils.py @@ -0,0 +1,89 @@ +from .utils import _parent_name +from collections import OrderedDict +import re + +def get_flattened_qconfig_dict(qconfig_dict): + """ flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. + + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened = dict() + if '' in qconfig_dict: + flattened[''] = qconfig_dict[''] + + def flatten_key(key): + if key in qconfig_dict: + for obj, qconfig in qconfig_dict[key]: + flattened[obj] = qconfig + + flatten_key('object_type') + flatten_key('module_name') + return flattened + +def convert_dict_to_ordered_dict(qconfig_dict): + """ Convert dict in qconfig_dict to ordered dict + """ + # convert a qconfig list for a type to OrderedDict + def _convert_to_ordered_dict(key, qconfig_dict): + qconfig_dict[key] = OrderedDict(qconfig_dict.get(key, [])) + + _convert_to_ordered_dict('object_type', qconfig_dict) + _convert_to_ordered_dict('module_name_regex', qconfig_dict) + _convert_to_ordered_dict('module_name', qconfig_dict) + +def get_module_type_qconfig(qconfig_dict, module_type, fallback_qconfig): + return qconfig_dict['object_type'].get( + module_type, fallback_qconfig) + +def get_function_qconfig(qconfig_dict, function, fallback_qconfig): + return qconfig_dict['object_type'].get(function, fallback_qconfig) + +def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig): + for regex_pattern, qconfig in \ + qconfig_dict['module_name_regex'].items(): + if re.match(regex_pattern, module_name): + # first match wins + return qconfig + return fallback_qconfig + +def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): + if module_name == '': + # module name qconfig not found + return fallback_qconfig + if module_name in qconfig_dict['module_name']: + return qconfig_dict['module_name'][module_name] + else: + parent, _ = _parent_name(module_name) + return get_module_name_qconfig(qconfig_dict, parent, fallback_qconfig) + +# get qconfig for module_name, +# fallback to module_name_regex_qconfig, module_type_qconfig, +# global_qconfig if necessary +def get_qconfig(modules, qconfig_dict, module_name, global_qconfig): + assert modules is not None + module_type_qconfig = get_module_type_qconfig( + qconfig_dict, type(modules[module_name]), global_qconfig) + module_name_regex_qconfig = get_module_name_regex_qconfig( + qconfig_dict, module_name, module_type_qconfig) + module_name_qconfig = get_module_name_qconfig( + qconfig_dict, module_name, module_name_regex_qconfig) + return module_name_qconfig diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 176cd7603286..73590ad60904 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -11,6 +11,7 @@ from ..quantization_mappings import ( get_static_quant_module_class, + get_dynamic_quant_module_class, get_quantized_operator, ) from ..utils import ( @@ -471,7 +472,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, ] assert node.op == 'call_module' emb_node = node - emb = quantizer.modules[emb_node.target] qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) if dtypes not in supported_dtypes: @@ -481,6 +481,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, "supported dtype combinations are: {}".format(dtypes, supported_dtypes)) return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + emb = quantizer.modules[emb_node.target] qemb = get_static_quant_module_class(type(emb)) quantized = qemb.from_float(emb) parent_name, name = _parent_name(emb_node.target) @@ -491,6 +492,48 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, load_arg(quantized=False)(emb_node.args), load_arg(quantized=False)(emb_node.kwargs)) +# TODO (maybe): merge with embedding quantize handler +@register_quant_pattern(torch.nn.GRUCell) +@register_quant_pattern(torch.nn.LSTMCell) +@register_quant_pattern(torch.nn.RNNCell) +@register_quant_pattern(torch.nn.LSTM) +@mark_input_output_not_observed() +class RNNDynamic(QuantizeHandler): + def __init__(self, quantizer: QuantizerCls, node: Node): + super().__init__(quantizer, node) + + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + # Supported combinations are: + # quant_type | activation | weight | activation_compute_type + # dynamic | float32 | qint8 | quint8 + # dynamic | float16 | float16 | None + # tuple (activation_dtype, weight_dtype, compute_dtype) + supported_dtypes = [ + (torch.float32, torch.qint8, torch.quint8), + (torch.float16, torch.float16, None), + ] + assert node.op == 'call_module' + qconfig = quantizer.qconfig_map[node.name] + dtypes = get_qconfig_dtypes(qconfig) + if dtypes not in supported_dtypes: + warnings.warn( + "dtype combination: {} is not " + "supported by Embedding/EmbeddingBag, " + "supported dtype combinations are: {}".format(dtypes, supported_dtypes)) + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + + module = quantizer.modules[node.target] + qmodule_cls = get_dynamic_quant_module_class(type(module)) + qmodule = qmodule_cls.from_float(module) + parent_name, name = _parent_name(node.target) + setattr(quantizer.modules[parent_name], name, qmodule) + return quantizer.quantized_graph.create_node( + 'call_module', + node.target, + load_arg(quantized=False)(node.args), + load_arg(quantized=False)(node.kwargs)) ARGS_TO_SKIP = { torch._ops.ops.quantized.hardswish: ['inplace'], diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index fe7dc53a8019..7da165b52309 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -51,11 +51,15 @@ _parent_name, quantize_node, get_custom_module_class_keys, + get_new_attr_name_with_prefix, + collect_producer_nodes, + graph_module_from_producer_nodes, + assert_and_get_unique_device, ) -from collections import OrderedDict +from .qconfig_utils import * + import warnings -import re from typing import Optional, Dict, Any, List, Union, Tuple, Set, Callable @@ -70,183 +74,9 @@ # Helper Functions # ------------------------ -# Returns a function that can get a new attribute name for module with given -# prefix, for example, -# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') -# >> new_name = get_new_observer_name(module) -# new_name will be an unused attribute name on module, e.g. `_observer_1` -def get_new_attr_name_with_prefix(prefix: str) -> Callable: - def get_new_attr_name(module: torch.nn.Module): - def get_attr_name(i: int): - return prefix + str(i) - i = 0 - attr_name = get_attr_name(i) - while hasattr(module, attr_name): - i += 1 - attr_name = get_attr_name(i) - return attr_name - return get_new_attr_name - -def collect_producer_nodes(node: Node) -> Optional[List[Node]]: - r''' Starting from a target node, trace back until we hit inpu or - getattr node. This is used to extract the chain of operators - starting from getattr to the target node, for example - def forward(self, x): - observed = self.observer(self.weight) - return F.linear(x, observed) - collect_producer_nodes(observed) will either return a list of nodes that - produces the observed node or None if we can't extract a self contained - graph without free variables(inputs of the forward function). - ''' - nodes = [node] - frontier = [node] - while frontier: - node = frontier.pop() - all_args = list(node.args) + list(node.kwargs.values()) - for arg in all_args: - if not isinstance(arg, Node): - continue - if arg.op == 'placeholder': - # hit input, can't fold in this case - return None - nodes.append(arg) - if not (arg.op == 'call_function' and arg.target == getattr): - frontier.append(arg) - return nodes - -def graph_module_from_producer_nodes( - root: GraphModule, producer_nodes: List[Node]) -> GraphModule: - r''' Construct a graph module from extracted producer nodes - from `collect_producer_nodes` function - Args: - root: the root module for the original graph - producer_nodes: a list of nodes we use to construct the graph - Return: - A graph module constructed from the producer nodes - ''' - assert len(producer_nodes) > 0, 'list of producer nodes can not be empty' - # since we traced back from node to getattrr - producer_nodes.reverse() - graph = Graph() - env: Dict[Any, Any] = {} - - def load_arg(a): - return map_arg(a, lambda node: env[node]) - for producer_node in producer_nodes: - env[producer_node] = graph.node_copy(producer_node, load_arg) - graph.output(load_arg(producer_nodes[-1])) - graph_module = GraphModule(root, graph) - return graph_module - -def assert_and_get_unique_device(module: torch.nn.Module) -> Any: - """ - Returns the unique device for a module, or None if no device is found. - Throws an error if multiple devices are detected. - """ - devices = {p.device for p in module.parameters()} | \ - {p.device for p in module.buffers()} - assert len(devices) <= 1, ( - "prepare only works with cpu or single-device CUDA modules, " - "but got devices {}".format(devices) - ) - device = next(iter(devices)) if len(devices) > 0 else None - return device - -def is_observed_standalone_module_node( - node: Node, modules: Dict[str, torch.nn.Module]) -> bool: - return node.op == 'call_module' and \ - is_observed_standalone_module(modules[node.target]) # type: ignore - - -def get_flattened_qconfig_dict(qconfig_dict): - """ flatten the global, object_type and module_name qconfig - to the same qconfig_dict so that it can be used by - propagate_qconfig_ function. - "module_name_regex" is ignored for now since it's not supported - in propagate_qconfig_, but it can be fixed later. - - For example: - Input: { - "": qconfig, - "object_type": [ - (torch.add, qconfig) - ], - "module_name": [ - ("conv", qconfig) - ] - } - - Output: { - "": qconfig, - torch.add: qconfig, - "conv": qconfig - } - """ - flattened = dict() - if '' in qconfig_dict: - flattened[''] = qconfig_dict[''] - - def flatten_key(key): - if key in qconfig_dict: - for obj, qconfig in qconfig_dict[key]: - flattened[obj] = qconfig - - flatten_key('object_type') - flatten_key('module_name') - return flattened - -def convert_dict_to_ordered_dict(qconfig_dict): - """ Convert dict in qconfig_dict to ordered dict - """ - # convert a qconfig list for a type to OrderedDict - def _convert_to_ordered_dict(key, qconfig_dict): - qconfig_dict[key] = OrderedDict(qconfig_dict.get(key, [])) - - _convert_to_ordered_dict('object_type', qconfig_dict) - _convert_to_ordered_dict('module_name_regex', qconfig_dict) - _convert_to_ordered_dict('module_name', qconfig_dict) - -def get_module_type_qconfig(qconfig_dict, module_type, fallback_qconfig): - return qconfig_dict['object_type'].get( - module_type, fallback_qconfig) - -def get_function_qconfig(qconfig_dict, function, fallback_qconfig): - return qconfig_dict['object_type'].get(function, fallback_qconfig) - -def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig): - for regex_pattern, qconfig in \ - qconfig_dict['module_name_regex'].items(): - if re.match(regex_pattern, module_name): - # first match wins - return qconfig - return fallback_qconfig - -def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): - if module_name == '': - # module name qconfig not found - return fallback_qconfig - if module_name in qconfig_dict['module_name']: - return qconfig_dict['module_name'][module_name] - else: - parent, _ = _parent_name(module_name) - return get_module_name_qconfig(qconfig_dict, parent, fallback_qconfig) - -# get qconfig for module_name, -# fallback to module_name_regex_qconfig, module_type_qconfig, -# global_qconfig if necessary -def get_qconfig(modules, qconfig_dict, module_name, global_qconfig): - assert modules is not None - module_type_qconfig = get_module_type_qconfig( - qconfig_dict, type(modules[module_name]), global_qconfig) - module_name_regex_qconfig = get_module_name_regex_qconfig( - qconfig_dict, module_name, module_type_qconfig) - module_name_qconfig = get_module_name_qconfig( - qconfig_dict, module_name, module_name_regex_qconfig) - return module_name_qconfig - def insert_observer( node: Node, observer: torch.quantization.ObserverBase, - model_device: Any, model: torch.nn.Module, + model: torch.nn.Module, activation_post_process_map: Dict[str, torch.quantization.ObserverBase], env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable, observed_node_names_set: Set[str]): @@ -257,6 +87,7 @@ def insert_observer( observer: observer/fake_quantize module instance """ # respect device affinity when adding observers + model_device = assert_and_get_unique_device(model) if model_device: observer.to(model_device) # add observer module as attribute @@ -313,7 +144,6 @@ def insert_observer_for_output_of_the_node( modules: Dict[str, torch.nn.Module], model: torch.nn.Module, pattern: Any, - model_device: Any, activation_post_process_map: Dict[str, torch.quantization.ObserverBase], env: Dict[Any, Any], observed_graph: Graph, @@ -338,7 +168,7 @@ def insert_observer_for_output_of_the_node( "activation_post_process constructor not provided " + \ "for pattern:" + str(pattern) insert_observer( - node, activation_post_process_ctr(), model_device, + node, activation_post_process_ctr(), model, activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set) elif (isinstance(quantize_handler, @@ -386,13 +216,13 @@ def input_is_observed(arg): # observer for outputs new_observer = qconfig.activation() insert_observer( - node, new_observer, model_device, model, + node, new_observer, model, activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set) def insert_observer_for_input_arg_of_observed_node( node: Node, observed_node_names_set: Set[str], quants: Dict[str, Any], - model_device: Any, model: torch.nn.Module, + model: torch.nn.Module, activation_post_process_map: Dict[str, torch.quantization.ObserverBase], env: Dict[str, str], observed_graph: Graph, load_arg: Callable): @@ -401,7 +231,7 @@ def insert_observer_for_input_arg_of_observed_node( if activation_post_process_ctr is not None: insert_observer( node, activation_post_process_ctr(), - model_device, model, activation_post_process_map, + model, activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set) # A dictionary for querying the weight index for a given op @@ -565,7 +395,6 @@ def load_arg(a): get_new_observer_name = get_new_attr_name_with_prefix( 'activation_post_process_') - model_device = assert_and_get_unique_device(model) result_node : Optional[Node] = None for node in model.graph.nodes: @@ -591,14 +420,14 @@ def load_arg(a): node) insert_observer_for_output_of_the_node( node, obj, qconfig, self.modules, model, pattern, - model_device, self.activation_post_process_map, env, + self.activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set, matched_nodes) else: env[node.name] = observed_graph.node_copy(node, load_arg) insert_observer_for_input_arg_of_observed_node( node, observed_node_names_set, quants, - model_device, model, self.activation_post_process_map, env, + model, self.activation_post_process_map, env, observed_graph, load_arg) @@ -852,8 +681,11 @@ def insert_quantize_node(node): quantized = False else: assert obj is not None - is_standalone_module_node = is_observed_standalone_module_node( - node, self.modules) + is_standalone_module_node = ( + node.op == 'call_module' and + is_observed_standalone_module( + self.modules[node.target]) # type: ignore + ) result = obj.convert( self, node, load_arg, debug=debug, convert_custom_config_dict=convert_custom_config_dict) diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index a07cbc6ef8e4..c1f849803342 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -2,6 +2,15 @@ import torch from ..utils import is_per_tensor, is_per_channel +from torch.fx import GraphModule, map_arg + +from torch.fx.graph import ( + Graph, + Node, +) + +from typing import Callable, Optional, List, Dict, Any + # turn foo.bar -> ['foo', 'bar'] def _parent_name(target): r = target.rsplit('.', 1) @@ -169,3 +178,85 @@ def get_linear_prepack_op_for_dtype(dtype): return torch.ops.quantized.linear_prepack else: raise Exception("can't get linear prepack op for dtype:", dtype) + +# Returns a function that can get a new attribute name for module with given +# prefix, for example, +# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') +# >> new_name = get_new_observer_name(module) +# new_name will be an unused attribute name on module, e.g. `_observer_1` +def get_new_attr_name_with_prefix(prefix: str) -> Callable: + def get_new_attr_name(module: torch.nn.Module): + def get_attr_name(i: int): + return prefix + str(i) + i = 0 + attr_name = get_attr_name(i) + while hasattr(module, attr_name): + i += 1 + attr_name = get_attr_name(i) + return attr_name + return get_new_attr_name + +def collect_producer_nodes(node: Node) -> Optional[List[Node]]: + r''' Starting from a target node, trace back until we hit inpu or + getattr node. This is used to extract the chain of operators + starting from getattr to the target node, for example + def forward(self, x): + observed = self.observer(self.weight) + return F.linear(x, observed) + collect_producer_nodes(observed) will either return a list of nodes that + produces the observed node or None if we can't extract a self contained + graph without free variables(inputs of the forward function). + ''' + nodes = [node] + frontier = [node] + while frontier: + node = frontier.pop() + all_args = list(node.args) + list(node.kwargs.values()) + for arg in all_args: + if not isinstance(arg, Node): + continue + if arg.op == 'placeholder': + # hit input, can't fold in this case + return None + nodes.append(arg) + if not (arg.op == 'call_function' and arg.target == getattr): + frontier.append(arg) + return nodes + +def graph_module_from_producer_nodes( + root: GraphModule, producer_nodes: List[Node]) -> GraphModule: + r''' Construct a graph module from extracted producer nodes + from `collect_producer_nodes` function + Args: + root: the root module for the original graph + producer_nodes: a list of nodes we use to construct the graph + Return: + A graph module constructed from the producer nodes + ''' + assert len(producer_nodes) > 0, 'list of producer nodes can not be empty' + # since we traced back from node to getattrr + producer_nodes.reverse() + graph = Graph() + env: Dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node]) + for producer_node in producer_nodes: + env[producer_node] = graph.node_copy(producer_node, load_arg) + graph.output(load_arg(producer_nodes[-1])) + graph_module = GraphModule(root, graph) + return graph_module + +def assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + devices = {p.device for p in module.parameters()} | \ + {p.device for p in module.buffers()} + assert len(devices) <= 1, ( + "prepare only works with cpu or single-device CUDA modules, " + "but got devices {}".format(devices) + ) + device = next(iter(devices)) if len(devices) > 0 else None + return device diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 88d264b1ccf3..c965de07deb7 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -124,6 +124,19 @@ def get_static_quant_module_class(float_module_class, additional_static_quant_ma " does not have a corresponding quantized module class" return static_quant_module_class +def get_dynamic_quant_module_class(float_module_class, additional_dynamic_quant_mapping=None): + r"""n Get the dynamically quantized module class corresponding to + the floating point module class + """ + if additional_dynamic_quant_mapping is None: + additional_dynamic_quant_mapping = {} + all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping) + dynamic_quant_module_class = all_mappings.get(float_module_class, None) + assert dynamic_quant_module_class is not None, \ + "Floating point module class {}".format(str(float_module_class)) + \ + " does not have a corresponding quantized module class" + return dynamic_quant_module_class + def get_default_qat_module_mappings(): ''' Get default module mapping for quantization aware training ''' diff --git a/torch/serialization.py b/torch/serialization.py index 7ae6abafa232..ebc5d0a08541 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -524,7 +524,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): deserialization methods using :func:`torch.serialization.register_package`. Args: - f: a file-like object (has to implement :meth:`read`, :meth`readline`, :meth`tell`, and :meth`seek`), + f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), or a string or os.PathLike object containing a file name map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 0126b1dd0a93..36f02eff0c0f 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -171,7 +171,7 @@ def _construct_test_name(test_name, op, device_type, dtype): if op is not None: - test_name += "_" + op.name + test_name += "_" + op.name.replace('.', '_') test_name += "_" + device_type diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 96d1cd03557e..b88dcaaccb33 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -10,11 +10,11 @@ from typing import List, Tuple, Dict, Any from torch.testing import \ - (make_non_contiguous, _dispatch_dtypes, - floating_types, floating_types_and, floating_and_complex_types, - floating_and_complex_types_and, all_types_and_complex_and, all_types_and) + (make_non_contiguous, _dispatch_dtypes, floating_types, floating_types_and, + floating_and_complex_types, floating_and_complex_types_and, + all_types_and_complex_and, all_types_and) from torch.testing._internal.common_device_type import \ - (skipCUDAIfNoMagma, skipCPUIfNoLapack, + (skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride) from torch.testing._internal.common_utils import \ (prod_single_zero, random_square_matrix_of_rank, @@ -22,7 +22,7 @@ random_symmetric_pd_matrix, make_nonzero_det, random_fullrank_matrix_distinct_singular_value, set_rng_seed, TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY, - torch_to_numpy_dtype_dict) + torch_to_numpy_dtype_dict, TEST_WITH_SLOW) if TEST_SCIPY: import scipy.special @@ -54,6 +54,23 @@ def __init__(self, input, *, args=tuple(), kwargs=None): self.kwargs = kwargs if kwargs is not None else {} +_NOTHING = object() # Unique value to distinguish default from anything else + + +# Extension of getattr to support qualified names +# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm +def _getattr_qual(obj, name, default=_NOTHING): + try: + for path in name.split('.'): + obj = getattr(obj, path) + return obj + except AttributeError: + if default is not _NOTHING: + return default + else: + raise + + # Classes and methods for the operator database class OpInfo(object): """Operator information and helper functions for acquiring it.""" @@ -84,13 +101,16 @@ def __init__(self, skips=tuple(), # information about which tests to skip decorators=None, # decorators to apply to generated tests promotes_integers_to_float=False, # whether op promotes unary output to float or not - sample_inputs_func=None): # function to generate sample inputs + sample_inputs_func=None, # function to generate sample inputs + aten_name=None, # name of the corresponding aten:: operator + ): # Validates the dtypes are generated from the dispatch-related functions for dtype_list in (dtypes, dtypesIfCPU, dtypesIfCUDA, dtypesIfROCM): assert isinstance(dtype_list, (_dispatch_dtypes, type(None))) self.name = name + self.aten_name = aten_name if aten_name is not None else name self.dtypes = set(dtypes) self.dtypesIfCPU = set(dtypesIfCPU) if dtypesIfCPU is not None else self.dtypes @@ -99,12 +119,10 @@ def __init__(self, self._default_test_dtypes = set(default_test_dtypes) if default_test_dtypes is not None else None # NOTE: if the op is unspecified it is assumed to be under the torch namespace - if op is None: - assert hasattr(torch, self.name), f"Can't find torch.{self.name}" - self.op = op if op else getattr(torch, self.name) - self.method_variant = getattr(torch.Tensor, name) if hasattr(torch.Tensor, name) else None + self.op = op if op else _getattr_qual(torch, self.name) + self.method_variant = getattr(torch.Tensor, name, None) inplace_name = name + "_" - self.inplace_variant = getattr(torch.Tensor, inplace_name) if hasattr(torch.Tensor, name) else None + self.inplace_variant = getattr(torch.Tensor, inplace_name, None) self.skip_bfloat16_grad = skip_bfloat16_grad self.test_inplace_grad = test_inplace_grad @@ -289,8 +307,71 @@ def wrapped_fn(x): return wrapped_fn + +# Metadata class for Fast Fourier Transforms in torch.fft. +class SpectralFuncInfo(OpInfo): + """Operator information for torch.fft transforms. """ + + def __init__(self, + name, # the string name of the function + *, + ref=None, # Reference implementation (probably in np.fft namespace) + dtypes=floating_and_complex_types(), + dtypesIfCPU=None, + dtypesIfCUDA=None, + dtypesIfROCM=None, + ndimensional: bool, # Whether dim argument can be a tuple + skips=None, + decorators=None, + **kwargs): + dtypesIfCPU = dtypesIfCPU if dtypesIfCPU is not None else dtypes + dtypesIfCUDA = dtypesIfCUDA if dtypesIfCUDA is not None else dtypes + dtypesIfROCM = dtypesIfROCM if dtypesIfROCM is not None else dtypes + + # gradgrad is quite slow + if not TEST_WITH_SLOW: + skips = skips if skips is not None else [] + skips.append(SkipInfo('TestGradients', 'test_fn_gradgrad')) + + decorators = decorators if decorators is not None else [] + decorators += [skipCPUIfNoMkl, skipCUDAIfRocm] + + super().__init__(name=name, + dtypes=dtypes, + dtypesIfCPU=dtypesIfCPU, + dtypesIfCUDA=dtypesIfCUDA, + dtypesIfROCM=dtypesIfROCM, + skips=skips, + decorators=decorators, + **kwargs) + self.ref = ref if ref is not None else _getattr_qual(np, name) + self.ndimensional = ndimensional + + + def sample_inputs(self, device, dtype, requires_grad=False): + tensor = make_tensor((L, M), device, dtype, + low=None, high=None, + requires_grad=requires_grad) + if self.ndimensional: + return [ + SampleInput(tensor), + SampleInput(tensor, kwargs=dict(dim=(-2,))), + SampleInput(tensor, kwargs=dict(norm='ortho')), + SampleInput(tensor, kwargs=dict(s=(10, 15))), + SampleInput(tensor, kwargs=dict(s=10, dim=1, norm='ortho')), + ] + else: + return [ + SampleInput(tensor), + SampleInput(tensor, kwargs=dict(dim=-2)), + SampleInput(tensor, kwargs=dict(norm='ortho')), + SampleInput(tensor, kwargs=dict(n=15)), + SampleInput(tensor, kwargs=dict(n=10, dim=1, norm='ortho')), + ] + + # Operator database (sorted alphabetically) -op_db: List[Any] = [ +op_db: List[OpInfo] = [ # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) UnaryUfuncInfo('acos', ref=np.arccos, @@ -437,10 +518,15 @@ def wrapped_fn(x): dtypes=[torch.float], active_if=TEST_WITH_ROCM), )), UnaryUfuncInfo('cosh', - ref=np.cosh, - dtypesIfCPU=floating_and_complex_types(), + ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + promotes_integers_to_float=True, assert_autodiffed=True, skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.int8]), SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cpu', @@ -448,6 +534,89 @@ def wrapped_fn(x): SkipInfo('TestCommon', 'test_variant_consistency_jit', device_type='cuda', dtypes=[torch.float16]), )), + SpectralFuncInfo('fft.fft', + aten_name='fft_fft', + ref=np.fft.fft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False,), + SpectralFuncInfo('fft.fftn', + aten_name='fft_fftn', + ref=np.fft.fftn, + ndimensional=True, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False, + decorators=[precisionOverride( + {torch.float: 1e-4, torch.cfloat: 1e-4})],), + SpectralFuncInfo('fft.hfft', + aten_name='fft_hfft', + ref=np.fft.hfft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False,), + SpectralFuncInfo('fft.rfft', + aten_name='fft_rfft', + ref=np.fft.rfft, + ndimensional=False, + dtypes=all_types_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False,), + SpectralFuncInfo('fft.rfftn', + aten_name='fft_rfftn', + ref=np.fft.rfftn, + ndimensional=True, + dtypes=all_types_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False, + decorators=[precisionOverride({torch.float: 1e-4})],), + SpectralFuncInfo('fft.ifft', + aten_name='fft_ifft', + ref=np.fft.ifft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False,), + SpectralFuncInfo('fft.ifftn', + aten_name='fft_ifftn', + ref=np.fft.ifftn, + ndimensional=True, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False,), + SpectralFuncInfo('fft.ihfft', + aten_name='fft_ihfft', + ref=np.fft.ihfft, + ndimensional=False, + dtypes=all_types_and(torch.bool), + default_test_dtypes=floating_types(), + supports_tensor_out=False, + test_inplace_grad=False,), + SpectralFuncInfo('fft.irfft', + aten_name='fft_irfft', + ref=np.fft.irfft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False,), + SpectralFuncInfo('fft.irfftn', + aten_name='fft_irfftn', + ref=np.fft.irfftn, + ndimensional=True, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=False, + test_inplace_grad=False,), UnaryUfuncInfo('log', ref=np.log, domain=(0, float('inf')), @@ -590,10 +759,26 @@ def wrapped_fn(x): active_if=(IS_MACOS or IS_WINDOWS)), )), UnaryUfuncInfo('exp2', - ref=np.exp2, - dtypes=floating_types_and(torch.half), - dtypesIfCPU=None, - dtypesIfCUDA=None), + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2), + dtypes=all_types_and(torch.bool, torch.half), + dtypesIfCPU=all_types_and(torch.bool, torch.half), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + promotes_integers_to_float=True), + UnaryUfuncInfo('expm1', + ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1), + dtypes=all_types_and(torch.bool, torch.half), + dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + promotes_integers_to_float=True, + assert_autodiffed=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/48926#issuecomment-739734774 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.bfloat16]), + # RuntimeError: "isfinite" not implemented for 'BFloat16' + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device_type='cpu', dtypes=[torch.bfloat16]), + )), UnaryUfuncInfo('nan_to_num', ref=np.nan_to_num, dtypes=all_types_and(torch.half, torch.bool), @@ -616,25 +801,13 @@ def wrapped_fn(x): # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.bfloat16]), - # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype. - SkipInfo('TestGradients', 'test_fn_grad', - dtypes=[torch.cdouble]), - SkipInfo('TestGradients', 'test_fn_gradgrad', - dtypes=[torch.cdouble]), - SkipInfo('TestGradients', 'test_method_grad', - dtypes=[torch.cdouble]), - SkipInfo('TestGradients', 'test_method_gradgrad', - dtypes=[torch.cdouble]), - SkipInfo('TestGradients', 'test_inplace_grad', - dtypes=[torch.cdouble]), - SkipInfo('TestGradients', 'test_inplace_gradgrad', - dtypes=[torch.cdouble]), SkipInfo('TestCommon', 'test_variant_consistency_eager', dtypes=[torch.cfloat, torch.cdouble]), SkipInfo('TestCommon', 'test_variant_consistency_jit', dtypes=[torch.cfloat, torch.cdouble])), promotes_integers_to_float=True, - handles_complex_extremals=False), + handles_complex_extremals=False, + test_complex_grad=False), ] if TEST_SCIPY: @@ -644,7 +817,7 @@ def reference_sigmoid(x): return (1 / (1 + np.exp(-x))) return scipy.special.expit(x) - op_db_scipy_reference = [ + op_db_scipy_reference: List[OpInfo] = [ UnaryUfuncInfo('sigmoid', ref=reference_sigmoid, decorators=(precisionOverride({torch.float16: 1e-2, @@ -695,6 +868,7 @@ def reference_sigmoid(x): # Common operator groupings unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)] +spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)] def index_variable(shape, max_indices): if not isinstance(shape, tuple): @@ -954,8 +1128,6 @@ def method_tests(): ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)), ('exp', (S, S, S), NO_ARGS, '', (True,)), ('exp', (), NO_ARGS, 'scalar', (True,)), - ('expm1', (S, S, S), NO_ARGS, '', (True,)), - ('expm1', (), NO_ARGS, 'scalar', (True,)), ('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS), ('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'), ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''), diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 2e3cc16b4540..2ff28c8d30ad 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -742,8 +742,8 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic self.assertTrue(expected_name in str(q_embeddingbag)) -# Below are a series of neural net models to use in testing quantization -# Single layer models +# Below are a series of toy models to use in testing quantization + class SingleLayerLinearModel(torch.nn.Module): def __init__(self): super().__init__() @@ -1350,7 +1350,7 @@ def __init__(self): self.downsample = torch.nn.Identity() self.myop = nn.quantized.FloatFunctional() self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - + self.fc = torch.nn.Linear(inplanes, 1) def forward(self, x): out = self.conv1(x) @@ -1360,8 +1360,13 @@ def forward(self, x): out = self.myop.add(out, identity) out = self.relu2(out) out = self.avgpool(out) + out = torch.flatten(out, 1) + out = self.fc(out) return out + def fuse_model(self): + torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True) + class ModelMultipleOps(torch.nn.Module): def __init__(self): super().__init__() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index cf997ddb894b..6577b1c4559f 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -21,6 +21,7 @@ import warnings import random import contextlib +import shutil import socket import subprocess import time @@ -300,11 +301,11 @@ def run_tests(argv=UNITTEST_ARGS): if IS_WINDOWS: @contextmanager - def TemporaryFileName(): + def TemporaryFileName(dir=None): # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile # opens the file, and it cannot be opened multiple times in Windows. To support Windows, # close the file after creation and try to remove it manually - f = tempfile.NamedTemporaryFile(delete=False) + f = tempfile.NamedTemporaryFile(delete=False, dir=dir) try: f.close() yield f.name @@ -312,10 +313,27 @@ def TemporaryFileName(): os.unlink(f.name) else: @contextmanager # noqa: T484 - def TemporaryFileName(): - with tempfile.NamedTemporaryFile() as f: + def TemporaryFileName(dir=None): + with tempfile.NamedTemporaryFile(dir=dir) as f: yield f.name +if IS_WINDOWS: + @contextmanager + def TemporaryDirectoryName(suffix=None): + # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely, + # so we first create the directory using mkdtemp and then remove it manually + try: + dir_name = tempfile.mkdtemp(suffix=suffix) + yield dir_name + finally: + shutil.rmtree(dir_name) +else: + @contextmanager # noqa: T484 + def TemporaryDirectoryName(suffix=None): + with tempfile.TemporaryDirectory(suffix=suffix) as d: + yield d + +IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8' def _check_module_exists(name): r"""Returns if a top-level module with :attr:`name` exists *without** @@ -831,7 +849,7 @@ def __init__(self, method_name='runTest'): # Wraps the tested method if we should enforce non default CUDA stream. self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True) - if self._do_cuda_non_default_stream and not IS_WINDOWS and not TEST_WITH_ROCM: + if self._do_cuda_non_default_stream and not IS_WINDOWS: self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream) def assertLeaksNoCudaTensors(self, name=None): @@ -1039,6 +1057,13 @@ def _compareScalars(self, a, b, *, return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + # Construct assert messages basd on internal debug message and user provided message. + def _get_assert_msg(self, msg, debug_msg=None): + if msg is None: + return debug_msg + else: + return f"\n{msg}" if debug_msg is None else f"{debug_msg}\n{msg}" + def assertEqualIgnoreType(self, *args, **kwargs) -> None: # If you are seeing this function used, that means test is written wrongly # and deserves detailed investigation @@ -1049,7 +1074,8 @@ def assertEqualIgnoreType(self, *args, **kwargs) -> None: def assertEqual(self, x, y, msg: Optional[str] = None, *, atol: Optional[float] = None, rtol: Optional[float] = None, equal_nan=True, exact_dtype=True, exact_device=False) -> None: - assert (atol is None) == (rtol is None), "If one of atol or rtol is specified the other must be, too" + assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too" + debug_msg: Optional[str] = None # Tensor x Number and Number x Tensor comparisons if isinstance(x, torch.Tensor) and isinstance(y, Number): @@ -1065,39 +1091,42 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *, elif isinstance(y, torch.Tensor) and isinstance(x, np.bool_): self.assertEqual(x, y.item(), atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, exact_device=exact_device) + # Tensor x Tensor elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): - super().assertEqual(x.is_sparse, y.is_sparse, msg=msg) - super().assertEqual(x.is_quantized, y.is_quantized, msg=msg) + debug_msg = ("Attempted to compare with different is_sparse settings: " + f"Expected: {x.is_sparse}; Actual: {y.is_sparse}.") + super().assertEqual(x.is_sparse, y.is_sparse, msg=self._get_assert_msg(msg=msg, debug_msg=debug_msg)) + debug_msg = ("Attempted to compare with different is_quantized settings: " + f"Expected: {x.is_quantized}; Actual: {y.is_quantized}.") + super().assertEqual(x.is_quantized, y.is_quantized, msg=self._get_assert_msg(msg=msg, debug_msg=debug_msg)) if x.is_sparse: if x.size() != y.size(): - debug_msg_sparse = ("Attempted to compare equality of tensors with different sizes. " - f"Got sizes {x.size()} and {y.size()}.") - if msg is None: - msg = debug_msg_sparse - self.assertTrue(False, msg=msg) + debug_msg_sparse = ("Attempted to compare equality of tensors with different sizes: " + f"Expected: {x.size()}; Actual: {y.size()}.") + super().assertTrue(False, msg=self._get_assert_msg(msg=msg, debug_msg=debug_msg_sparse)) x = x.coalesce() y = y.coalesce() - indices_result, debug_msg = self._compareTensors(x._indices(), y._indices(), - rtol=rtol, atol=atol, - equal_nan=equal_nan, exact_dtype=exact_dtype, - exact_device=exact_device) - - if not indices_result and msg is None: - assert debug_msg is not None - msg = "Sparse tensor indices failed to compare as equal! " + debug_msg - self.assertTrue(indices_result, msg=msg) - - values_result, debug_msg = self._compareTensors(x._values(), y._values(), - rtol=rtol, atol=atol, - equal_nan=equal_nan, exact_dtype=exact_dtype, - exact_device=exact_device) - - if not values_result and msg is None: - assert debug_msg is not None - msg = "Sparse tensor values failed to compare as equal! " + debug_msg - self.assertTrue(values_result, msg=msg) + indices_result, debug_msg_indices = self._compareTensors(x._indices(), y._indices(), + rtol=rtol, atol=atol, + equal_nan=equal_nan, exact_dtype=exact_dtype, + exact_device=exact_device) + + if not indices_result: + assert debug_msg_indices is not None + debug_msg = "Sparse tensor indices failed to compare as equal! " + debug_msg_indices + super().assertTrue(indices_result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) + + values_result, debug_msg_values = self._compareTensors(x._values(), y._values(), + rtol=rtol, atol=atol, + equal_nan=equal_nan, exact_dtype=exact_dtype, + exact_device=exact_device) + + if not values_result: + assert debug_msg_values is not None + debug_msg = "Sparse tensor values failed to compare as equal! " + debug_msg_values + super().assertTrue(values_result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif x.is_quantized and y.is_quantized: self.assertEqual(x.qscheme(), y.qscheme(), atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, @@ -1121,30 +1150,33 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *, atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, exact_device=exact_device) - result, debug_msg = self._compareTensors(x.int_repr().to(torch.int32), - y.int_repr().to(torch.int32), - atol=atol, rtol=rtol, - exact_dtype=exact_dtype, - exact_device=exact_device) + result, debug_msg_compare = self._compareTensors(x.int_repr().to(torch.int32), + y.int_repr().to(torch.int32), + atol=atol, rtol=rtol, + exact_dtype=exact_dtype, + exact_device=exact_device) - if not result and msg is None: - assert debug_msg is not None - msg = "Quantized representations failed to compare as equal! " + debug_msg - self.assertTrue(result, msg=msg) + if not result: + assert debug_msg_compare is not None + debug_msg = "Quantized representations failed to compare as equal! " + debug_msg_compare + super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) else: - result, debug_msg = self._compareTensors(x, y, rtol=rtol, atol=atol, - equal_nan=equal_nan, exact_dtype=exact_dtype, - exact_device=exact_device) + result, debug_msg_generic = self._compareTensors(x, y, rtol=rtol, atol=atol, + equal_nan=equal_nan, exact_dtype=exact_dtype, + exact_device=exact_device) if not result: - assert debug_msg is not None - msg = msg or "Tensors failed to compare as equal!" - msg = f'{msg}\n{debug_msg}' - self.assertTrue(result, msg=msg) + assert debug_msg_generic is not None + debug_msg = "Tensors failed to compare as equal!" + debug_msg_generic + super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif isinstance(x, string_classes) and isinstance(y, string_classes): - super().assertEqual(x, y, msg=msg) + debug_msg = ("Attempted to compare [string] types: " + f"Expected: {repr(x)}; Actual: {repr(y)}.") + super().assertEqual(x, y, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif type(x) == set and type(y) == set: - super().assertEqual(x, y, msg=msg) + debug_msg = ("Attempted to compare [set] types: " + f"Expected: {x}; Actual: {y}.") + super().assertEqual(x, y, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif isinstance(x, dict) and isinstance(y, dict): if isinstance(x, OrderedDict) and isinstance(y, OrderedDict): self.assertEqual(x.items(), y.items(), atol=atol, rtol=rtol, @@ -1161,23 +1193,27 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *, exact_dtype=exact_dtype, exact_device=exact_device) elif isinstance(x, type) and isinstance(y, type): # See TestTorch.test_assert_equal_generic_meta - super().assertEqual(x, y, msg=msg) + debug_msg = ("Attempted to compare [type] types: " + f"Expected: {x}; Actual: {y}.") + super().assertEqual(x, y, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif is_iterable(x) and is_iterable(y): - super().assertEqual(len(x), len(y), msg=msg) + debug_msg = ("Attempted to compare the lengths of [iterable] types: " + f"Expected: {len(x)}; Actual: {len(y)}.") + super().assertEqual(len(x), len(y), msg=self._get_assert_msg(msg, debug_msg=debug_msg)) for x_, y_ in zip(x, y): self.assertEqual(x_, y_, atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, exact_device=exact_device) elif isinstance(x, bool) and isinstance(y, bool): - self.assertTrue(x == y, msg=msg) + super().assertTrue(x == y, msg=msg) # Scalar x Scalar elif isinstance(x, Number) and isinstance(y, Number): - result, debug_msg = self._compareScalars(x, y, rtol=rtol, atol=atol, - equal_nan=equal_nan) - if not result and msg is None: - assert debug_msg is not None - msg = "Scalars failed to compare as equal! " + debug_msg - self.assertTrue(result, msg=msg) + result, debug_msg_scalars = self._compareScalars(x, y, rtol=rtol, atol=atol, + equal_nan=equal_nan) + if not result: + assert debug_msg_scalars is not None + debug_msg = "Scalars failed to compare as equal! " + debug_msg_scalars + super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) # Tensor x Numpy array elif isinstance(x, torch.Tensor) and isinstance(y, np.ndarray): self.assertEqual(x, torch.from_numpy(y), atol=atol, rtol=rtol, msg=msg, diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index cbe8e9d630bf..5577d2322679 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -67,6 +67,12 @@ def __eq__(self, other): [1, 2, True, "string", [4, 5, "nested"]], ] +# Allowlist of distributed backends where profiling collectives is supported. +PROFILING_SUPPORTED_BACKENDS = [ + dist.Backend.NCCL, + dist.Backend.GLOO, +] + # Dummy NamedTuple data structures to test DDP support for NamedTuple types. EXPECTED_FIELDS = ("a", "b") TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS) @@ -1283,7 +1289,7 @@ def test_all_reduce_result_cuda(self): self.assertEqual(result, [_build_tensor(src + 1, expected_value)]) self._barrier() - def call_dist_op(self, profiling_title_postfix, is_async, op, *args, expect_event=False, secondary_op_call=None, **kwargs): + def call_dist_op(self, profiling_title_postfix, is_async, op, *args, expect_event=True, secondary_op_call=None, **kwargs): op_calls = [lambda: op(*args, **kwargs)] if secondary_op_call is not None: op_calls.append(secondary_op_call) @@ -1297,12 +1303,12 @@ def call_dist_op(self, profiling_title_postfix, is_async, op, *args, expect_even def get_event(postfix): return [event for event in prof.function_events if event.name.endswith(postfix)] - if expect_event: + if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS: events = get_event(profiling_title_postfix) self.assertEqual(len(events), len(op_calls)) for e in events: self.assertEqual(e.count, 1) - self.assertGreater(e.cpu_time, 0) + self.assertGreaterEqual(e.cpu_time, 0) # ALL REDUCE def _test_all_reduce_helper( @@ -2474,12 +2480,9 @@ def _test_reduce_multigpu_helper( _build_tensor(src + 1, master_value).cuda(device=i) for i in rank_to_GPU[rank] ] - # TODO: Setting expect_event=False to disable profiling - # tests. Once https://github.com/pytorch/pytorch/issues/48127 - # is addressed, this should be reverted. self.call_dist_op( "reduce", False, dist.reduce_multigpu, tensors, src, op, group_id, - expect_event=False) + expect_event=len(tensors) == 1) expected_tensor = _build_tensor(src + 1, expected_value) self.assertEqual(tensors[0], expected_tensor) else: @@ -2487,12 +2490,9 @@ def _test_reduce_multigpu_helper( _build_tensor(src + 1, worker_value).cuda(device=i) for i in rank_to_GPU[rank] ] - # TODO: Setting expect_event=False to disable profiling - # tests. Once https://github.com/pytorch/pytorch/issues/48127 - # is addressed, this should be reverted. self.call_dist_op( "reduce", False, dist.reduce_multigpu, tensors, src, op, group_id, - expect_event=False) + expect_event=len(tensors) == 1) self._barrier() @@ -2532,13 +2532,10 @@ def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU, d for gpu in rank_to_GPU[rank]: output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu]) - # TODO: Setting expect_event=False to disable profiling - # tests. Once https://github.com/pytorch/pytorch/issues/48127 - # is addressed, this should be reverted. self.call_dist_op( "all_gather", False, dist.all_gather_multigpu, output_tensors, tensors, group_id, - expect_event=False) + expect_event=len(expected_output) == 1) self.assertEqual(output_tensors, expected_output) self._barrier() diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index d6b3d816fe68..4f14584af3b1 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -217,6 +217,21 @@ def test_remote_parameters(self): self.assertEqual(len(param_rrefs), 1) self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL)) + @dist_utils.dist_init + def test_get_module_rref(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + # Only test Python nn.Module, because script module methods don't support ``get_module_rref``. + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + rref = remote_module.get_module_rref() + self.assertEqual(rref, remote_module.module_rref) + for param in rref.to_here().parameters(): + self.assertTrue(torch.equal(param, _PARAM_VAL)) + @skip_if_lt_x_gpu(1) @dist_utils.dist_init def test_valid_device(self): @@ -270,16 +285,6 @@ def test_invalid_devices(self): ) ) - with self.assertRaisesRegex( - RuntimeError, r"CPU device index must be -1 or zero, got 2" - ): - list( - self._create_remote_module_iter( - "{}/cpu:2".format(dst_worker_name), - modes=[ModuleCreationMode.MODULE_CTOR], - ) - ) - with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): list( self._create_remote_module_iter( diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 8036c73a6330..4a91394d53c5 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -229,8 +229,15 @@ def the_method({}): return {} ''' +def value_to_literal(value): + if isinstance(value, str): + # Quotes string and escapes special characters + return ascii(value) + else: + return str(value) + def get_call(method_name, func_type, args, kwargs): - kwargs_str = ', '.join([k + '=' + str(v) for k, v in kwargs.items()]) + kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()]) self_arg = args[0] if(func_type == 'method'): args = args[1:] @@ -461,12 +468,12 @@ def make_module(script): return module return script_module -def check_alias_annotation(method_name, args, kwargs): +def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'): formals, tensors, actuals = get_script_args(args) - call = get_call(method_name, 'method', actuals, kwargs) + call = get_call(method_name, func_type, actuals, kwargs) script = script_template.format(', '.join(formals), call) CU = torch.jit.CompilationUnit(script) - torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name) + torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name) def get_nn_module_name_from_kwargs(**kwargs): if 'module_name' in kwargs: diff --git a/torch/testing/check_kernel_launches.py b/torch/testing/check_kernel_launches.py index 091f1be98561..c274316b54fe 100644 --- a/torch/testing/check_kernel_launches.py +++ b/torch/testing/check_kernel_launches.py @@ -18,12 +18,13 @@ # But this should be sufficient to detect and fix most problem # instances and can be refined before the test is made binding kernel_launch_regex = re.compile(r""" - >>> # Identifies kernel launch + ^.*>>> # Identifies kernel launch \s* # Maybe some whitespace (includes newlines) \([^;]+\); # And then arguments in parens and semi-colon (?! # Negative lookahead: we trigger if we don't find the launch guard \s* # Maybe some whitespace (includes newlines) \\? # 0 or 1 backslashes (for launches in preprocessor macros) + \s* # Maybe some whitespace (includes newlines) (?:[0-9]+: )? # Detects and ignores a line numbering, if present \s* # Maybe some whitespace (includes newlines) C10_CUDA_KERNEL_LAUNCH_CHECK\(\); # Kernel launch guard! diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index cde58f85cada..5b91c7a9a0fa 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -43,7 +43,10 @@ def run(command): stderr=subprocess.PIPE, shell=True) raw_output, raw_err = p.communicate() rc = p.returncode - enc = locale.getpreferredencoding() + if get_platform() == 'win32': + enc = 'oem' + else: + enc = locale.getpreferredencoding() output = raw_output.decode(enc) err = raw_err.decode(enc) return rc, output.strip(), err.strip() @@ -70,7 +73,7 @@ def run_and_parse_first_match(run_lambda, command, regex): def get_conda_packages(run_lambda): if get_platform() == 'win32': - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') findstr_cmd = os.path.join(system_root, 'System32', 'findstr') grep_cmd = r'{} /R "torch numpy cudatoolkit soumith mkl magma"'.format(findstr_cmd) else: @@ -125,7 +128,7 @@ def get_running_cuda_version(run_lambda): def get_cudnn_version(run_lambda): """This will return a list of libcudnn.so; it's hard to tell which one is being used""" if get_platform() == 'win32': - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") where_cmd = os.path.join(system_root, 'System32', 'where') cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) @@ -163,7 +166,15 @@ def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux smi = 'nvidia-smi' if get_platform() == 'win32': - smi = '"C:\\Program Files\\NVIDIA Corporation\\NVSMI\\%s"' % smi + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = f'"{candidate_smi}"' + break return smi @@ -185,7 +196,7 @@ def get_mac_version(run_lambda): def get_windows_version(run_lambda): - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') findstr_cmd = os.path.join(system_root, 'System32', 'findstr') return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) @@ -236,7 +247,7 @@ def get_pip_packages(run_lambda): # People generally have `pip` as `pip` or `pip3` def run_with_pip(pip): if get_platform() == 'win32': - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') findstr_cmd = os.path.join(system_root, 'System32', 'findstr') grep_cmd = r'{} /R "numpy torch"'.format(findstr_cmd) else: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index b84ebe95d525..7837d8cbb570 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -828,6 +828,35 @@ def CUDAExtension(name, sources, *args, **kwargs): cmdclass={ 'build_ext': BuildExtension }) + + Compute capabilities: + + By default the extension will be compiled to run on all archs of the cards visible during the + building process of the extension, plus PTX. If down the road a new card is installed the + extension may need to be recompiled. If a visible card has a compute capability (CC) that's + newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch + will make nvcc fall back to building kernels with the newest version of PTX your nvcc does + support (see below for details on PTX). + + You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which + CCs you want the extension to support: + + TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py + TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py + + The +PTX option causes extension kernel binaries to include PTX instructions for the specified + CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >= + the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with + CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to + provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on + those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better + off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6, + "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but + "8.0 8.6" would be better. + + Note that while it's possible to include all supported archs, the more archs get included the + slower the building process will be, as it will build a separate kernel image for each arch. + ''' library_dirs = kwargs.get('library_dirs', []) library_dirs += library_paths(cuda=True) @@ -1496,16 +1525,24 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: # If not given, determine what's best for the GPU / CUDA version that can be found if not _arch_list: - capability = torch.cuda.get_device_capability() - supported_sm = [int(arch.split('_')[1]) - for arch in torch.cuda.get_arch_list() if 'sm_' in arch] - max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm) - # Capability of the device may be higher than what's supported by the user's - # NVCC, causing compilation error. User's NVCC is expected to match the one - # used to build pytorch, so we use the maximum supported capability of pytorch - # to clamp the capability. - capability = min(max_supported_sm, capability) - arch_list = [f'{capability[0]}.{capability[1]}'] + arch_list = [] + # the assumption is that the extension should run on any of the currently visible cards, + # which could be of different types - therefore all archs for visible cards should be included + for i in range(torch.cuda.device_count()): + capability = torch.cuda.get_device_capability(i) + supported_sm = [int(arch.split('_')[1]) + for arch in torch.cuda.get_arch_list() if 'sm_' in arch] + max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm) + # Capability of the device may be higher than what's supported by the user's + # NVCC, causing compilation error. User's NVCC is expected to match the one + # used to build pytorch, so we use the maximum supported capability of pytorch + # to clamp the capability. + capability = min(max_supported_sm, capability) + arch = f'{capability[0]}.{capability[1]}' + if arch not in arch_list: + arch_list.append(arch) + arch_list = sorted(arch_list) + arch_list[-1] += '+PTX' else: # Deal with lists that are ' ' separated (only deal with ';' after) _arch_list = _arch_list.replace(' ', ';') diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 1eb60c81f7d0..a46d01797f16 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -498,6 +498,7 @@ def __init__(self, loader: DataLoader) -> None: self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item() self._persistent_workers = loader.persistent_workers self._num_yielded = 0 + self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__) def __iter__(self) -> '_BaseDataLoaderIter': return self @@ -514,22 +515,23 @@ def _next_data(self): raise NotImplementedError def __next__(self) -> Any: - if self._sampler_iter is None: - self._reset() - data = self._next_data() - self._num_yielded += 1 - if self._dataset_kind == _DatasetKind.Iterable and \ - self._IterableDataset_len_called is not None and \ - self._num_yielded > self._IterableDataset_len_called: - warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " - "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called, - self._num_yielded) - if self._num_workers > 0: - warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the " - "IterableDataset replica at each worker. Please see " - "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.") - warnings.warn(warn_msg) - return data + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + self._reset() + data = self._next_data() + self._num_yielded += 1 + if self._dataset_kind == _DatasetKind.Iterable and \ + self._IterableDataset_len_called is not None and \ + self._num_yielded > self._IterableDataset_len_called: + warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " + "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called, + self._num_yielded) + if self._num_workers > 0: + warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.") + warnings.warn(warn_msg) + return data next = __next__ # Python 2 compatibility diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py index cb67625df518..e048b54a462c 100644 --- a/torch/utils/data/distributed.py +++ b/torch/utils/data/distributed.py @@ -67,6 +67,10 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank