diff --git a/.circleci/cimodel/data/simple/util/versions.py b/.circleci/cimodel/data/simple/util/versions.py index 3c9186df13aa..53d3a837248c 100644 --- a/.circleci/cimodel/data/simple/util/versions.py +++ b/.circleci/cimodel/data/simple/util/versions.py @@ -29,3 +29,6 @@ def __init__(self, major, minor): self.minor = minor super().__init__([self.major, self.minor], "cuda") + + def __str__(self): + return f"{self.major}.{self.minor}" diff --git a/.circleci/cimodel/data/windows_build_definitions.py b/.circleci/cimodel/data/windows_build_definitions.py index dea78411addb..c0e828eaab5e 100644 --- a/.circleci/cimodel/data/windows_build_definitions.py +++ b/.circleci/cimodel/data/windows_build_definitions.py @@ -86,10 +86,11 @@ def gen_tree(self): props_dict["executor"] = "windows-with-nvidia-gpu" props_dict["cuda_version"] = ( - miniutils.quote(str(self.cuda_version.major)) + miniutils.quote(str(self.cuda_version)) if self.cuda_version else "cpu" ) + props_dict["name"] = "_".join(name_parts) return [{key_name: props_dict}] diff --git a/.circleci/config.yml b/.circleci/config.yml index 8bdfb3c9c7bd..cdd66830986f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -325,7 +325,7 @@ pytorch_windows_params: &pytorch_windows_params default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -675,7 +675,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -737,7 +737,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -8077,7 +8077,7 @@ workflows: - postnightly - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" name: pytorch_windows_vs2019_py36_cuda10.1_build python_version: "3.6" use_cuda: "1" @@ -8086,7 +8086,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" executor: windows-with-nvidia-gpu name: pytorch_windows_vs2019_py36_cuda10.1_test1 python_version: "3.6" @@ -8099,7 +8099,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" executor: windows-with-nvidia-gpu name: pytorch_windows_vs2019_py36_cuda10.1_test2 python_version: "3.6" @@ -8112,7 +8112,7 @@ workflows: vc_year: "2019" - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" name: pytorch_windows_vs2019_py36_cuda11.1_build python_version: "3.6" use_cuda: "1" @@ -8121,7 +8121,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" executor: windows-with-nvidia-gpu filters: branches: @@ -8140,7 +8140,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" executor: windows-with-nvidia-gpu filters: branches: @@ -8204,7 +8204,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" filters: branches: only: diff --git a/.circleci/scripts/windows_cuda_install.sh b/.circleci/scripts/windows_cuda_install.sh index 8d615b674aa0..04a4c2ed43ff 100644 --- a/.circleci/scripts/windows_cuda_install.sh +++ b/.circleci/scripts/windows_cuda_install.sh @@ -1,13 +1,11 @@ #!/bin/bash set -eux -o pipefail -if [[ "$CUDA_VERSION" == "10" ]]; then - cuda_complete_version="10.1" +if [[ "$CUDA_VERSION" =~ ^10.* ]]; then cuda_installer_name="cuda_10.1.243_426.00_win10" msbuild_project_dir="CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions" cuda_install_packages="nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" -elif [[ "$CUDA_VERSION" == "11" ]]; then - cuda_complete_version="11.1" +elif [[ "$CUDA_VERSION" =~ ^11.* ]]; then cuda_installer_name="cuda_11.1.0_456.43_win10" msbuild_project_dir="visual_studio_integration/CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions" cuda_install_packages="nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" @@ -16,7 +14,7 @@ else exit 1 fi -if [[ "${CUDA_VERSION}" != "10" && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then +if [[ "$CUDA_VERSION" =~ ^11.* && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then cuda_install_packages="${cuda_install_packages} Display.Driver" fi @@ -48,7 +46,7 @@ then export NVTOOLSEXT_PATH="C:\\Program Files\\NVIDIA Corporation\\NvToolsExt\\" fi -if ! ls "/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${cuda_complete_version}/bin/nvcc.exe" +if ! ls "/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${CUDA_VERSION}/bin/nvcc.exe" then echo "CUDA installation failed" mkdir -p /c/w/build-results diff --git a/.circleci/scripts/windows_cudnn_install.sh b/.circleci/scripts/windows_cudnn_install.sh index 529710af79b2..62f54615677e 100644 --- a/.circleci/scripts/windows_cudnn_install.sh +++ b/.circleci/scripts/windows_cudnn_install.sh @@ -1,12 +1,10 @@ #!/bin/bash set -eux -o pipefail -if [[ "$CUDA_VERSION" == "10" ]]; then - cuda_complete_version="10.1" - cudnn_installer_name="cudnn-10.1-windows10-x64-v7.6.4.38" -elif [[ "$CUDA_VERSION" == "11" ]]; then - cuda_complete_version="11.1" - cudnn_installer_name="cudnn-11.1-windows-x64-v8.0.5.39" +if [[ "$CUDA_VERSION" =~ ^10.* ]]; then + cudnn_installer_name="cudnn-${CUDA_VERSION}-windows10-x64-v7.6.4.38" +elif [[ "$CUDA_VERSION" =~ ^11.* ]]; then + cudnn_installer_name="cudnn-${CUDA_VERSION}-windows-x64-v8.0.5.39" else echo "CUDNN for CUDA_VERSION $CUDA_VERSION is not supported yet" exit 1 @@ -16,6 +14,6 @@ cudnn_installer_link="https://ossci-windows.s3.amazonaws.com/${cudnn_installer_n curl --retry 3 -O $cudnn_installer_link 7z x ${cudnn_installer_name}.zip -ocudnn -cp -r cudnn/cuda/* "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${cuda_complete_version}/" +cp -r cudnn/cuda/* "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${CUDA_VERSION}/" rm -rf cudnn rm -f ${cudnn_installer_name}.zip diff --git a/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml b/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml index e031e01ba846..c912a4fb690b 100644 --- a/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml +++ b/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml @@ -59,7 +59,7 @@ pytorch_windows_params: &pytorch_windows_params default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" diff --git a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml index 8d8036ea9523..aa0e2d2c5581 100644 --- a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml +++ b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml @@ -237,7 +237,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -299,7 +299,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 1f1f174e992e..8e9afd5c9bc3 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -11,6 +11,8 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" +export LANG=C.UTF-8 + if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then export PYTORCH_TEST_WITH_SLOW=1 export PYTORCH_TEST_SKIP_FAST=1 diff --git a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat index f41e5f7fcd1b..7165f75a0e41 100644 --- a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat +++ b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat @@ -37,33 +37,19 @@ if "%VC_VERSION%" == "" ( @echo on popd -if "%CUDA_VERSION%" == "9" goto cuda_build_9 -if "%CUDA_VERSION%" == "10" goto cuda_build_10 -if "%CUDA_VERSION%" == "11" goto cuda_build_11 -goto cuda_build_end +if not "%USE_CUDA%"=="1" goto cuda_build_end -:cuda_build_9 +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION% -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2 -set CUDA_PATH_V9_2=%CUDA_PATH% +rem version transformer, for example 10.1 to 10_1. +set VERSION_SUFFIX=%CUDA_VERSION:.=_% +set CUDA_PATH_V%VERSION_SUFFIX%=%CUDA_PATH% -goto cuda_build_common - -:cuda_build_10 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1 -set CUDA_PATH_V10_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_11 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1 -set CUDA_PATH_V11_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_common +set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 +set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% +set CUDNN_ROOT_DIR=%CUDA_PATH% +set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt +set PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH% set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat index d4821c1b1a8d..ab102a0ea423 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat @@ -1,9 +1,9 @@ -if "%CUDA_VERSION%" == "9" set CUDA_SUFFIX=cuda92 -if "%CUDA_VERSION%" == "10" set CUDA_SUFFIX=cuda101 -if "%CUDA_VERSION%" == "11" set CUDA_SUFFIX=cuda110 +rem remove dot in cuda_version, fox example 11.1 to 111 +set VERSION_SUFFIX=%CUDA_VERSION:.=% +set CUDA_SUFFIX=cuda%VERSION_SUFFIX% if "%CUDA_SUFFIX%" == "" ( - echo unknown CUDA version, please set `CUDA_VERSION` to 9, 10 or 11. + echo unknown CUDA version, please set `CUDA_VERSION` higher than 9.2 exit /b 1 ) diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat index e3625ae75e9e..a052a1b67d59 100644 --- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat +++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat @@ -46,33 +46,13 @@ if %errorlevel% neq 0 ( exit /b %errorlevel% ) set DISTUTILS_USE_SDK=1 -if "%CUDA_VERSION%" == "9" goto cuda_build_9 -if "%CUDA_VERSION%" == "10" goto cuda_build_10 -if "%CUDA_VERSION%" == "11" goto cuda_build_11 -goto cuda_build_end +if not "%USE_CUDA%"=="1" goto cuda_build_end -:cuda_build_9 +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION% -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2 -set CUDA_PATH_V9_2=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_10 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1 -set CUDA_PATH_V10_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_11 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1 -set CUDA_PATH_V11_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_common +rem version transformer, for example 10.1 to 10_1. +set VERSION_SUFFIX=%CUDA_VERSION:.=_% +set CUDA_PATH_V%VERSION_SUFFIX%=%CUDA_PATH% set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% diff --git a/BUILD.bazel b/BUILD.bazel index 5da8edc2c34e..ec5111c5104d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -339,7 +339,8 @@ filegroup( "aten/src/ATen/cuda/CUDABlas.cpp", "aten/src/ATen/cuda/CUDASolver.cpp", "aten/src/ATen/cuda/CUDAContext.cpp", - "aten/src/ATen/cuda/CUDAGenerator.cpp", + "aten/src/ATen/cuda/CUDAGeneratorImpl.cpp", + "aten/src/ATen/cuda/CUDAGraph.cpp", "aten/src/ATen/cuda/CuSparseHandlePool.cpp", "aten/src/ATen/cuda/CublasHandlePool.cpp", "aten/src/ATen/cuda/CusolverDnHandlePool.cpp", diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f2981c0dbb37..51e9d1382808 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -674,7 +674,7 @@ ccache -M 25Gi ``` To check this is working, do two clean builds of pytorch in a row. The second -build should be substantially and noticeably faster than the first build. +build should be substantially and noticeably faster than the first build. If this doesn't seem to be the case, check that each of the symlinks above actually link to your installation of `ccache`. For example, if you followed the first option and installed `ccache` from source on a Linux machine, running `readlink -e $(which g++)` should return `~/ccache/bin/ccache`. #### Use a faster linker diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp index e4bb4c083160..9cc71f117d93 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp @@ -90,13 +90,13 @@ class PytorchJni : public facebook::jni::HybridClass { #endif #ifdef TRACE_ENABLED - static bool onFunctionEnter( + static std::unique_ptr onFunctionEnter( const at::RecordFunction& fn) { Trace::beginSection(fn.name().str()); - return true; + return nullptr; } - static void onFunctionExit(const at::RecordFunction&) { + static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) { Trace::endSection(); } #endif diff --git a/aten/src/ATen/CUDAGeneratorImpl.h b/aten/src/ATen/CUDAGeneratorImpl.h index ec83128c7013..9a9febd01f8e 100644 --- a/aten/src/ATen/CUDAGeneratorImpl.h +++ b/aten/src/ATen/CUDAGeneratorImpl.h @@ -131,8 +131,8 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl { uint64_t seed() override; void set_philox_offset_per_thread(uint64_t offset); uint64_t philox_offset_per_thread(); - void graph_prologue(int64_t* offset_extragraph); - uint64_t graph_epilogue(); + void capture_prologue(int64_t* offset_extragraph); + uint64_t capture_epilogue(); PhiloxCudaState philox_cuda_state(uint64_t increment); // Temporarily accommodates call sites that use philox_engine_inputs. @@ -147,6 +147,7 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl { uint64_t philox_offset_per_thread_ = 0; int64_t* offset_extragraph_; uint32_t offset_intragraph_ = 0; + bool graph_expects_this_gen_ = false; }; namespace cuda { diff --git a/aten/src/ATen/cpu/vec256/missing_vst1_neon.h b/aten/src/ATen/cpu/vec256/missing_vst1_neon.h index dbb2ba479f85..dffd5dbb862e 100644 --- a/aten/src/ATen/cpu/vec256/missing_vst1_neon.h +++ b/aten/src/ATen/cpu/vec256/missing_vst1_neon.h @@ -4,6 +4,5 @@ __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_f32_x2 (float32_t * __a, float32x4x2_t val) { - asm ("st1 {%S0.4s - %T0.4s}, [%1]" :: "w" (val), "r"(__a) :); + asm ("st1 {%S1.4s - %T1.4s}, [%2]" : "=m" (*__a) : "w" (val), "r"(__a) : "memory"); } - diff --git a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h index 58a677dc5de0..43389fe61583 100644 --- a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h @@ -25,7 +25,7 @@ static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) { static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) { __m256i lo = _mm256_castps_si256(a); __m256i hi = _mm256_castps_si256(b); - __m256i nan = _mm256_set1_epi32(0x7fc0); + __m256i nan = _mm256_set1_epi32(0xffff); __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q)); __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q)); __m256i ones = _mm256_set1_epi32(0x1); diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index f0db9014163a..8a5e4f48e0c0 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -84,7 +84,7 @@ Generator createCUDAGenerator(DeviceIndex device_index) { */ CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index) : c10::GeneratorImpl{Device(DeviceType::CUDA, device_index), - DispatchKeySet(c10::DispatchKey::CUDA)} { + DispatchKeySet(c10::DispatchKey::CUDA)} { at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl"); } @@ -101,20 +101,18 @@ void CUDAGeneratorImpl::set_current_seed(uint64_t seed) { } #define CAPTURE_DEFAULT_GENS_MSG \ -"Non-default (user-constructed) CUDA RNG generators cannot be used " \ -"in regions captured by CUDA graphs. " \ -"If you need a non-default CUDA generator in a captured region, " \ -"please file an issue." +"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \ +"generator on the device that's current when capture begins. " \ +"If you need a non-default (user-supplied) generator, or a generator on another " \ +"device, please file an issue." /** * Gets the current seed of CUDAGeneratorImpl. */ uint64_t CUDAGeneratorImpl::current_seed() const { - TORCH_CHECK((at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) || - ((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index())), - CAPTURE_DEFAULT_GENS_MSG); + // Debatable if current_seed() should be allowed in captured regions. + // Conservatively disallow it for now. + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed"); return seed_; } @@ -151,25 +149,21 @@ uint64_t CUDAGeneratorImpl::philox_offset_per_thread() { } /** - * Prepares this instance for a cuda graph capture region. + * Called by CUDAGraph to prepare this instance for a graph capture region. * offset_extragraph is the initial offset at the start of the graphed region. * offset_intragraph tracks the offset in the graphed region. */ -void CUDAGeneratorImpl::graph_prologue(int64_t* offset_extragraph) { - TORCH_CHECK((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()), - CAPTURE_DEFAULT_GENS_MSG); +void CUDAGeneratorImpl::capture_prologue(int64_t* offset_extragraph) { offset_extragraph_ = offset_extragraph; offset_intragraph_ = 0; + graph_expects_this_gen_ = true; } /** - * Finalizes a cuda graph capture region for this instance. + * Called by CUDAGraph to finalize a graph capture region for this instance. */ -uint64_t CUDAGeneratorImpl::graph_epilogue() { - TORCH_CHECK((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()), - CAPTURE_DEFAULT_GENS_MSG); +uint64_t CUDAGeneratorImpl::capture_epilogue() { + graph_expects_this_gen_ = false; return offset_intragraph_; } @@ -187,7 +181,7 @@ uint64_t CUDAGeneratorImpl::graph_epilogue() { * it intends to generate. * * Increment should be at least the number of curand() random numbers used in - * each thread. It is the user's responsibility to make sure that the increment + * each thread. It is the user's responsibility to make sure the increment * for philox is never smaller than the number of curand() calls. Increment * value > the number of curand() calls won't harm but anything less would mean * that you would be reusing random values from previous calls. @@ -196,17 +190,20 @@ uint64_t CUDAGeneratorImpl::graph_epilogue() { */ PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) { if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - TORCH_CHECK((void*)this == - (void*)&at::cuda::detail::getDefaultCUDAGenerator(device_.index()), + TORCH_CHECK(graph_expects_this_gen_, + "philox_cuda_state for an unexpected CUDA generator used during capture. " CAPTURE_DEFAULT_GENS_MSG); uint32_t offset = this->offset_intragraph_; TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <= - std::numeric_limits::max() - increment); + std::numeric_limits::max() - increment); this->offset_intragraph_ += increment; return PhiloxCudaState(this->seed_, this->offset_extragraph_, offset); } else { + TORCH_CHECK(!graph_expects_this_gen_, + "CUDA generator expects graph capture to be underway, " + "but the current stream is not capturing."); uint64_t offset = this->philox_offset_per_thread_; this->philox_offset_per_thread_ += increment; return PhiloxCudaState(this->seed_, offset); diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp new file mode 100644 index 000000000000..74cc5ca09793 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -0,0 +1,168 @@ +#include +#include +#include +#include + +namespace at { +namespace cuda { + +/** + * Note [CUDA Graph Wrapper Class] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Q: Why do we need graph capture and launch bindings in Pytorch? + * Why can't they live in a user extension, for example? + * + * A1: Convenience. + * A2: To ensure valid numerics on replay, some native CUDA ops (like RNG ops with + * CPU statefulness) need cooperation from the capture and replay bindings + * (see Note [CUDA Graph-safe RNG states] in CUDAGeneratorImpl.h). + * + * We can't expect users to know about this cooperation. If users write capture + * bindings naively in an extension, they likely won't interact with the native + * ops properly. Their graphs would yield invalid numerics on replay. + */ + +CUDAGraph::CUDAGraph() + // CUDAStreams may not be default-constructed. + : capture_stream_(at::cuda::getCurrentCUDAStream()) { +#if CUDA_VERSION < 11000 + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::capture_begin() { +#if CUDA_VERSION >= 11000 + TORCH_CHECK(!has_graph_exec_, + "This CUDAGraph instance already owns a captured graph. " + "To capture a new graph, create a new instance."); + + // For now, a CUDAGraph instance only accommodates the default generator on the device that's + // current when capture begins. If any op in the captured region uses a non-default generator, + // or a generator on another device, the offending generator will throw an error. + // These restrictions simplify CUDAGraph, but could be relaxed in the future: + // in principle, the underlying Cuda calls do permit cross-device ops to be captured. + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + + auto options = TensorOptions().device(at::kCUDA).dtype(at::kLong); + offset_extragraph_ = at::empty({1}, options); + + gen->capture_prologue(offset_extragraph_.data_ptr()); + + auto stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK(stream != at::cuda::getDefaultCUDAStream(), + "CUDA graphs must be captured on a non-default stream. " + "(However, after capture, it's ok to replay them on the " + "default stream.)"); + + capture_stream_ = stream; + capture_gen_ = gen; + + // cudaStreamCaptureModeGlobal is the most conservative option to + // prevent potentially unsafe CUDA API calls during capture. See + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); + + // Stashes the current graph's uuid. + cudaStreamCaptureStatus status; + AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id_)); + TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive); +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::capture_end() { +#if CUDA_VERSION >= 11000 + auto stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK(stream == capture_stream_, + "Capture must end on the same stream it began on."); + + AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_)); + TORCH_CHECK(graph_ != NULL, "Invalid capture."); + has_graph_ = true; + + // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, + // who prefer not to report error message through these arguments moving forward + // (they prefer return value, or errors on api calls internal to the capture) + AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); + has_graph_exec_ = true; + + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + TORCH_CHECK(gen == capture_gen_, + "Default CUDA RNG generator on current device at capture end " + "is different from default generator on current device " + "when capture began"); + wholegraph_increment_ = gen->capture_epilogue(); + + // Now that we've instantiated graph_ into graph_exec_, + // we don't need graph_ anymore. + AT_CUDA_CHECK(cudaGraphDestroy(graph_)); + has_graph_ = false; +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::replay() { +#if CUDA_VERSION >= 11000 + TORCH_CHECK(has_graph_exec_, + "Called CUDAGraph::replay without a preceding successful capture."); + + { + c10::OptionalDeviceGuard device_guard{capture_stream_.device()}; + + // Just like any RNG consumer kernel! + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + PhiloxCudaState rng_engine_inputs; + { + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(wholegraph_increment_); + } + offset_extragraph_.fill_(int64_t(rng_engine_inputs.offset_.val)); + + // graph_exec_ may be replayed in any stream. + AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream())); + } +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::reset() { +#if CUDA_VERSION >= 11000 + // I'd prefer these checks throw exceptions, not print warnings, + // but the destructor calls reset(), and at least one CI build + // refuses to compile with a throwing destructor. + // + // Instead of calling reset() in the destructor to clean up, I could + // call reset() in the __del__ method of a thin Python wrapper, + // in which case reset would be allowed to throw exceptions. + // But Stackoverflow does not like user-defined __del__. + // __del__ prevents Graph instances from EVER being garbage collected + // if they participate in a reference cycle. + // And exceptions thrown in __del__ only print a warning anyway. + // + // Calling reset() in the C++ destructor, with warnings instead of exceptions + // if calls fail, is the compromise we chose. + if (has_graph_) { + C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_)); + } + if (has_graph_exec_) { + C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_)); + } +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +CUDAGraph::~CUDAGraph() { + reset(); +} + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h new file mode 100644 index 000000000000..387271715055 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -0,0 +1,43 @@ +#include +#include +#include +#include + +namespace at { +namespace cuda { + +struct TORCH_CUDA_API CUDAGraph { + CUDAGraph(); + ~CUDAGraph(); + + void capture_begin(); + void capture_end(); + void replay(); + void reset(); + + protected: +#if CUDA_VERSION >= 11000 + cudaGraph_t graph_ = NULL; + cudaGraphExec_t graph_exec_ = NULL; +#endif + + // internal states for error checking + bool has_graph_ = false; + bool has_graph_exec_ = false; + + // uuid, retrieved from Cuda + unsigned long long id_; + + // Stream on which capture began + at::cuda::CUDAStream capture_stream_; + + // Default generator on device where capture began + at::CUDAGeneratorImpl* capture_gen_; + + // RNG state trackers + at::Tensor offset_extragraph_; + uint64_t wholegraph_increment_; +}; + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 28b9738034e7..00424ab83ba0 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -163,7 +163,9 @@ bool CUDAHooks::hasPrimaryContext(int64_t device_index) const { TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(), "hasPrimaryContext expects a valid device index, but got device_index=", device_index); unsigned int ctx_flags; - int ctx_is_active; + // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird + // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero. + int ctx_is_active = 0; AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active)); return ctx_is_active == 1; } diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index 9802797874b9..9778aa035cb1 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -1,7 +1,6 @@ #include #include -#include -#include +#include namespace at { @@ -9,295 +8,66 @@ namespace native { namespace { - inline int start_index(int a, int b, int c) { - return (int)std::floor((float)(a * c) / b); - } - - inline int end_index(int a, int b, int c) { - return (int)std::ceil((float)((a + 1) * c) / b); - } - - template - static void adaptive_avg_pool2d_single_out_frame( - scalar_t *input_p, - scalar_t *output_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW, - int64_t istrideD, - int64_t istrideH, - int64_t istrideW) - { - at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { - for (auto d = start; d < end; d++) - { - /* loop over output */ - int64_t oh, ow; - for(oh = 0; oh < osizeH; oh++) - { - int istartH = start_index(oh, osizeH, isizeH); - int iendH = end_index(oh, osizeH, isizeH); - int kH = iendH - istartH; - - for(ow = 0; ow < osizeW; ow++) - { - int istartW = start_index(ow, osizeW, isizeW); - int iendW = end_index(ow, osizeW, isizeW); - int kW = iendW - istartW; - - /* local pointers */ - scalar_t *ip = input_p + d*istrideD + istartH*istrideH + istartW*istrideW; - scalar_t *op = output_p + d*osizeH*osizeW + oh*osizeW + ow; - - /* compute local average: */ - scalar_t sum = 0; - int ih, iw; - for(ih = 0; ih < kH; ih++) - { - for(iw = 0; iw < kW; iw++) - { - scalar_t val = *(ip + ih*istrideH + iw*istrideW); - sum += val; - } - } - - /* set output to local average */ - *op = sum / kW / kH; - } - } - } - }); - } - - template - void adaptive_avg_pool2d_out_frame( - scalar_t *input_p, - scalar_t *output_p, - int64_t sizeB, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW, - int64_t istrideB, - int64_t istrideD, - int64_t istrideH, - int64_t istrideW) - { - at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - adaptive_avg_pool2d_single_out_frame( - input_p + b * istrideB, - output_p + b * sizeD * osizeH * osizeW, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); - } - }); - } - void adaptive_avg_pool2d_out_cpu_template( at::Tensor& output, at::Tensor const& input, IntArrayRef output_size) { TORCH_CHECK(output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2"); - for (int64_t i = 0; i < input.ndimension(); i++) { + int64_t ndim = input.ndimension(); + for (int64_t i = 0; i < ndim; i++) { TORCH_CHECK(input.size(i) > 0, "adaptive_avg_pooling2d(): expected input to have non-empty spatial dimensions, " "but input has sizes ", input.sizes(), " with dimension ", i, " being " "empty"); } - TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), + TORCH_CHECK((ndim == 3 || ndim == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); + TORCH_CHECK(input.dtype() == output.dtype(), + "expected dtype ", input.dtype(), " for `output` but got dtype ", output.dtype()); - /* sizes */ - int64_t sizeD = input.size(-3); - int64_t isizeH = input.size(-2); - int64_t isizeW = input.size(-1); - /* strides */ - int64_t istrideD = input.stride(-3); - int64_t istrideH = input.stride(-2); - int64_t istrideW = input.stride(-1); - - auto osizeH = output_size[0]; - auto osizeW = output_size[1]; - - /* resize output */ - if (input.ndimension() == 3 || input.size(-4) == 1) - { - if (input.ndimension() == 3) { - output.resize_({sizeD, osizeH, osizeW}); - } else { - output.resize_({1, sizeD, osizeH, osizeW}); - } - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { - auto input_data = input.data_ptr(); - auto output_data = output.data_ptr(); - adaptive_avg_pool2d_single_out_frame( - input_data, - output_data, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); - } - ); - } - else - { - int64_t sizeB = input.size(-4); - output.resize_({sizeB, sizeD, osizeH, osizeW}); - int64_t istrideB = input.stride(-4); + int64_t channels = input.size(-3); + int64_t input_height = input.size(-2); + int64_t input_width = input.size(-1); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { - auto input_data = input.data_ptr(); - auto output_data = output.data_ptr(); - adaptive_avg_pool2d_out_frame( - input_data, - output_data, - sizeB, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideB, - istrideD, - istrideH, istrideW); - }); + if (ndim == 3) { + output.resize_({channels, output_height, output_width}); + } else { + int64_t nbatch = input.size(0); + output.resize_({nbatch, channels, output_height, output_width}, input.suggest_memory_format()); } - } - - template - static void adaptive_avg_pool2d_backward_single_out_frame( - scalar_t *gradInput_p, - scalar_t *gradOutput_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW) - { - at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { - for (auto d = start; d < end; d++) - { - scalar_t *gradInput_p_d = gradInput_p + d*isizeW*isizeH; - scalar_t *gradOutput_p_d = gradOutput_p + d*osizeW*osizeH; - - /* calculate average */ - int64_t oh, ow; - for(oh = 0; oh < osizeH; oh++) - { - int istartH = start_index(oh, osizeH, isizeH); - int iendH = end_index(oh, osizeH, isizeH); - int kH = iendH - istartH; - for(ow = 0; ow < osizeW; ow++) - { - - int istartW = start_index(ow, osizeW, isizeW); - int iendW = end_index(ow, osizeW, isizeW); - int kW = iendW - istartW; - - scalar_t grad_delta = gradOutput_p_d[oh*osizeW +ow] / kH / kW; - - int ih, iw; - for(ih = istartH; ih < iendH; ih++) - { - for(iw = istartW; iw < iendW; iw++) - { - /* update gradient */ - gradInput_p_d[ih*isizeW + iw] += grad_delta; - } - } - } - } - } - }); - } - - template - void adaptive_avg_pool2d_backward_out_frame( - scalar_t *gradInput_p, - scalar_t *gradOutput_p, - int64_t sizeB, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW) - { - at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - scalar_t *gradInput_p_d = gradInput_p + b * sizeD * isizeW * isizeH; - scalar_t *gradOutput_p_d = gradOutput_p + b * sizeD * osizeW * osizeH; - adaptive_avg_pool2d_backward_single_out_frame( - gradInput_p_d, - gradOutput_p_d, - sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - }); + adaptive_avg_pool2d_kernel(kCPU, output, input, output_size); } Tensor& adaptive_avg_pool2d_backward_out_cpu_template( - Tensor& gradInput, - const Tensor& gradOutput_, + Tensor& grad_input, + const Tensor& grad_output, const Tensor& input) { - /* sizes */ - int sizeD = input.size(-3); - int isizeH = input.size(-2); - int isizeW = input.size(-1); - int osizeH = gradOutput_.size(-2); - int osizeW = gradOutput_.size(-1); - - /* get contiguous gradOutput */ - auto gradOutput = gradOutput_.contiguous(); + int64_t ndim = grad_output.ndimension(); + for (int64_t i = 0; i < ndim; i++) { + TORCH_CHECK(grad_output.size(i) > 0, + "adaptive_avg_pooling2d_backward(): expected grad_output to have non-empty spatial dimensions, " + "but grad_output has sizes ", grad_output.sizes(), " with dimension ", i, " being " + "empty"); + } - /* backprop */ - if (input.ndimension() == 3 || input.size(-4) == 1) - { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { - /* get raw pointers */ - scalar_t *gradInput_data = gradInput.data_ptr(); - scalar_t *gradOutput_data = gradOutput.data_ptr(); + TORCH_CHECK((ndim == 3 || ndim == 4), + "non-empty 3D or 4D (batch mode) tensor expected for grad_output"); + TORCH_CHECK(input.dtype() == grad_output.dtype(), + "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype()); + TORCH_CHECK(input.dtype() == grad_input.dtype(), + "expected dtype ", input.dtype(), " for `grad_input` but got dtype ", grad_input.dtype()); - adaptive_avg_pool2d_backward_single_out_frame( - gradInput_data, gradOutput_data, - sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - ); - } - else - { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { - /* get raw pointers */ - scalar_t *gradInput_data = gradInput.data_ptr(); - scalar_t *gradOutput_data = gradOutput.data_ptr(); - int64_t sizeB = input.size(-4); + grad_input.resize_(input.sizes(), input.suggest_memory_format()); + grad_input.zero_(); - adaptive_avg_pool2d_backward_out_frame( - gradInput_data, gradOutput_data, - sizeB, sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - ); - } - return gradInput; + adaptive_avg_pool2d_backward_kernel(kCPU, grad_input, grad_output); + return grad_input; } } // namespace @@ -346,25 +116,27 @@ namespace { } Tensor& adaptive_avg_pool2d_backward_out_cpu( - Tensor& gradInput, - const Tensor& gradOutput, + Tensor& grad_input, + const Tensor& grad_output, const Tensor& input) { - gradInput.resize_as_(input); adaptive_avg_pool2d_backward_out_cpu_template( - gradInput, gradOutput, input); - return gradInput; + grad_input, grad_output, input); + return grad_input; } Tensor adaptive_avg_pool2d_backward_cpu( - const Tensor& gradOutput, + const Tensor& grad_output, const Tensor& input) { - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_input = at::empty({0}, input.options()); adaptive_avg_pool2d_backward_out_cpu_template( - gradInput, gradOutput, input); - return gradInput; + grad_input, grad_output, input); + return grad_input; } +DEFINE_DISPATCH(adaptive_avg_pool2d_kernel); +DEFINE_DISPATCH(adaptive_avg_pool2d_backward_kernel); + } // at::native } // at diff --git a/aten/src/ATen/native/AdaptivePooling.h b/aten/src/ATen/native/AdaptivePooling.h new file mode 100644 index 000000000000..29b2fd1c94c9 --- /dev/null +++ b/aten/src/ATen/native/AdaptivePooling.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel); +DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel); + +static inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (int64_t)std::floor((float)(a * c) / b); +} + +static inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return (int64_t)std::ceil((float)((a + 1) * c) / b); +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/PixelShuffle.cpp b/aten/src/ATen/native/PixelShuffle.cpp index 14c126f77bdf..e6301e682d77 100644 --- a/aten/src/ATen/native/PixelShuffle.cpp +++ b/aten/src/ATen/native/PixelShuffle.cpp @@ -4,31 +4,51 @@ #include #include +#include #include namespace at { namespace native { Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) { - AT_ASSERTM(self.dim() == 4, - "pixel_shuffle expects 4D input, but got input with sizes ",self.sizes()); - int64_t b = self.size(0); - int64_t c = self.size(1); - int64_t h = self.size(2); - int64_t w = self.size(3); + TORCH_CHECK(self.dim() >= 3, + "pixel_shuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), " dimension(s)"); + // Format: (B1, ..., Bn), C, H, W + int64_t c = self.size(-3); + int64_t h = self.size(-2); + int64_t w = self.size(-1); + const auto NUM_NON_BATCH_DIMS = 3; + const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS; + int64_t upscale_factor_squared = upscale_factor * upscale_factor; - AT_ASSERTM(c % upscale_factor_squared == 0, - "pixel_shuffle expects input channel to be divisible by square of " - "upscale_factor, but got input with sizes ", self.sizes(), - ", upscale_factor=", upscale_factor, - ", and self.size(1)=", c, " is not divisible by ", upscale_factor_squared); + TORCH_CHECK(c % upscale_factor_squared == 0, + "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " + "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared); int64_t oc = c / upscale_factor_squared; int64_t oh = h * upscale_factor; int64_t ow = w * upscale_factor; - auto input_reshaped = self.reshape({b, oc, upscale_factor, upscale_factor, h, w}); - return input_reshaped.permute({0 /* b */, 1 /* oc */, 4 /* h */, 2 /* 1st upscale_factor */, 5 /* w */, 3 /* 2nd upscale_factor */}) - .reshape({b, oc, oh, ow}); + // First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor). + // This allows shuffling to be done next by permuting dims. + std::vector expanded_shape(self.sizes().begin(), last_batch_dim); + expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w}); + const auto input_expanded = self.reshape(expanded_shape); + + // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims. + std::vector permutation(self.sizes().begin(), last_batch_dim); + // std::iota is used to maintain the batch dims within the permutation. + // Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6. + std::iota(permutation.begin(), permutation.end(), -expanded_shape.size()); + permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */, + -3 /* 2nd upscale_factor */}); + const auto input_permuted = input_expanded.permute(permutation); + + // Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh) + // and (w, upscale_factor) -> a single dim (ow). + std::vector final_shape(self.sizes().begin(), last_batch_dim); + final_shape.insert(final_shape.end(), {oc, oh, ow}); + return input_permuted.reshape(final_shape); } }} // namespace at::native diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 65d67629fa9f..4ae2ee326b88 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -19,12 +19,6 @@ namespace at { namespace native { -// Common code for all FFT functions -static inline Tensor _fft( - const Tensor &self, int64_t signal_ndim, bool complex_input, - const bool complex_output, bool inverse, IntArrayRef signal_sizes, - fft_norm_mode normalization, bool onesided); - namespace { // Promote inputs to FFT functions @@ -416,139 +410,6 @@ Tensor fft_ifftshift(const Tensor& x, c10::optional dim_opt) { } -// This is a pass-through wrapper function that does the size check and -// inferences. The actual forward implementation function is called -// at::_fft_with_size which dispatches to _fft_cufft (CUDA) or _fft_mkl (CPU). -static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim, - const bool complex_input, const bool complex_output, - const bool inverse, IntArrayRef signal_sizes, - const fft_norm_mode normalization, const bool onesided) { - - TORCH_CHECK(signal_ndim >= 1 && signal_ndim <= 3, - "Expected signal_ndim to be 1, 2, or 3, but got signal_ndim=", - signal_ndim); - TORCH_CHECK(at::isFloatingType(self.scalar_type()), - "Expected an input tensor of floating types, but got input=", - self.toString(), self.sizes()); - - auto signal_tensor_ndim = signal_ndim + static_cast(complex_input); // add complex dim - if (self.dim() < signal_tensor_ndim) { - std::ostringstream ss; - ss << "Given signal_ndim=" << signal_ndim << ", expected an input tensor " - << "of at least " << signal_tensor_ndim << "D"; - if (complex_input) { - ss << " (complex input adds an extra dimension)"; - } - ss << ", but got input=" << self.toString() << self.sizes(); - AT_ERROR(ss.str()); - } - - auto self_shape = self.sizes(); - auto batch_ndim = self.dim() - signal_tensor_ndim; - - Tensor input = self; - // flatten the batch dims - if (batch_ndim == 0) { - // slightly faster path for non-batch mode - input = input.unsqueeze(0); - } else if (batch_ndim > 1) { - std::vector flatten_input_shape(signal_tensor_ndim + 1); - std::copy(self_shape.begin() + batch_ndim, self_shape.end(), flatten_input_shape.begin() + 1); - flatten_input_shape[0] = -1; - input = input.reshape(flatten_input_shape); - - } - - // now we assume that input is batched as [ B x signal_dims... ] - - if (complex_input) { - TORCH_CHECK(input.size(signal_ndim + 1) == 2, - "Expected an input tensor with a last dimension of size 2 " - "representing real + imaginary components, but got input ", - self.toString(), self.sizes()); - } - - // build signal_sizes and output_size - TORCH_CHECK(signal_sizes.size() == 0 || static_cast(signal_sizes.size()) == signal_ndim, - "Expected signal_sizes to be empty (default) or of signal_ndim=", - signal_ndim, "D, but got signal_sizes=", signal_sizes); - std::vector output_sizes(signal_ndim + 1 + static_cast(complex_output)); - output_sizes[0] = input.size(0); // batch size - std::vector checked_signal_sizes(signal_ndim); - for (int64_t i = 0; i < signal_ndim; i++) { - int64_t input_size = input.size(i + 1); - if (i == signal_ndim - 1 && onesided && complex_input && !complex_output) { - // If last dim and complex-to-real onesided, input is only half of - // signal, and we need to infer basing on signal_sizes, if given - // See native/SpectralOpsUtils.h for detailed description. - int64_t inferred_size; - if (signal_sizes.size() > 0) { - inferred_size = infer_ft_complex_to_real_onesided_size(input_size, signal_sizes[i]); - } else { - inferred_size = infer_ft_complex_to_real_onesided_size(input_size); - } - checked_signal_sizes[i] = inferred_size; - output_sizes[i + 1] = inferred_size; - } else { - if (i == signal_ndim - 1 && onesided && !complex_input && complex_output) { - // if last dim and real-to-complex onesided, output should be only - // half of the signal, and we need to infer using input_size - output_sizes[i + 1] = infer_ft_real_to_complex_onesided_size(input_size); - } else { - output_sizes[i + 1] = input_size; - } - checked_signal_sizes[i] = input_size; - TORCH_CHECK(signal_sizes.size() == 0 || signal_sizes[i] == checked_signal_sizes[i], - "Expected given signal_sizes=", signal_sizes," to have same " - "shape with input at signal dimension ", i, ", but got " - "signal_sizes=", signal_sizes, " and input=", self.toString(), - self.sizes()); - } - } - if (complex_output) { - output_sizes[signal_ndim + 1] = 2; - } - - Tensor output = at::_fft_with_size(input, signal_ndim, complex_input, - complex_output, inverse, - checked_signal_sizes, - static_cast(normalization), - onesided, - output_sizes); - - // unflatten the batch dims - if (batch_ndim == 0) { - // slightly faster path for non-batch mode - output = output.squeeze(0); - } else if (batch_ndim > 1) { - auto output_ndim = self.dim() + static_cast(complex_output) - static_cast(complex_input); - std::vector unflatten_output_shape(output_ndim); - std::copy(self_shape.begin(), self_shape.begin() + batch_ndim, unflatten_output_shape.begin()); - std::copy(output_sizes.begin() + 1, output_sizes.end(), unflatten_output_shape.begin() + batch_ndim); - output = output.reshape(unflatten_output_shape); - } - return output; -} - -// Wrapper to preserve the historic signature of _fft_with_size -// NOTE: This is only used for torchscript backwards compatibility and the new -// signature with normalization modes should be used in all other cases -Tensor _fft_with_size(const Tensor& input, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - bool normalized, bool onesided, - IntArrayRef output_sizes) { - fft_norm_mode norm; - if (normalized) { - norm = fft_norm_mode::by_root_n; - } else { - norm = inverse ? fft_norm_mode::by_n : fft_norm_mode::none; - } - return at::_fft_with_size( - input, signal_ndim, complex_input, complex_output, inverse, - checked_signal_sizes, static_cast(norm), onesided, output_sizes); -} - // We call the following methods via CUDA hooks because they are really only // valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details. int64_t _cufft_get_plan_cache_max_size(int64_t device_index) { @@ -567,36 +428,6 @@ void _cufft_clear_plan_cache(int64_t device_index) { detail::getCUDAHooks().cuFFTClearPlanCache(device_index); } -static Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) { - return _fft(self, signal_ndim, /* complex_input */ true, - /* complex_output */ true, /* inverse */ false, {}, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none, - /* onesided */ false); -} - -static Tensor ifft(const Tensor& self, const int64_t signal_ndim, const bool normalized) { - return _fft(self, signal_ndim, /* complex_input */ true, - /* complex_output */ true, /* inverse */ true, {}, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n, - /* onesided */ false); -} - -static Tensor rfft(const Tensor& self, const int64_t signal_ndim, const bool normalized, - const bool onesided) { - return _fft(self, signal_ndim, /* complex_input */ false, - /* complex_output */ true, /* inverse */ false, {}, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none, - onesided); -} - -static Tensor irfft(const Tensor& self, const int64_t signal_ndim, const bool normalized, - const bool onesided, IntArrayRef signal_sizes) { - return _fft(self, signal_ndim, /* complex_input */ true, - /* complex_output */ false, /* inverse */ true, signal_sizes, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n, - onesided); -} - template static Stream& write_opt(Stream& SS, const optional& value) { if (value) { diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 8b5fdd44d789..f3147bdf78aa 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -471,6 +471,87 @@ Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const T return self.clone(at::MemoryFormat::Preserve).index_add_(dim, index, source); } +// Check that indices fall within dimension array size +// Avoid redispatch call to min/max +template +static void check_indexarray_range( + const IndexType* indices, + int64_t n, + IndexType indexing_axis_dim) { + for (auto i = 0; i < n; ++i) { + auto idx = indices[i]; + TORCH_CHECK( + 0 <= idx && idx < indexing_axis_dim, + "INDICES element is out of DATA bounds, id=", + idx, + " axis_dim=", + indexing_axis_dim); + } +} + +Tensor & index_select_out_cpu_dim1_( + Tensor & result_contig, const Tensor & self, const Tensor & index_contig) { + + auto self_contig = self.contiguous(); + const caffe2::TypeMeta dataType = self_contig.dtype(); + size_t item_bytesize = dataType.itemsize(); + + auto out = static_cast(result_contig.data_ptr()); + + auto src_base = static_cast(self_contig.data_ptr()); + + auto self_sizes = self_contig.sizes(); + auto outer_dims_product = c10::size_to_dim_(1, self_sizes); + auto block_size = c10::size_from_dim_(2, self_sizes); + auto block_bytesize = block_size * item_bytesize; + + auto src_indexing_axis_dim = self_sizes[1]; + auto src_batch_bytesize = self_sizes[1] * block_bytesize; + auto N = index_contig.numel(); + + auto gathered_batch_bytesize = N * block_bytesize; + + AT_DISPATCH_INDEX_TYPES( + index_contig.scalar_type(), "batch_index_select_compute", [&]() { + + const auto* idxs = index_contig.data_ptr(); + check_indexarray_range(idxs, N, src_indexing_axis_dim); + + // Special-case single-float copy for efficiency + if (self.scalar_type() == ScalarType::Float && block_size == 1) { + for (auto batch = 0; batch < outer_dims_product; ++batch) { + const float* src_floats = + (const float*)(src_base + batch * src_batch_bytesize); + float* dst_floats = (float*)(out + batch * gathered_batch_bytesize); + + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (idx < 0) { + idx = idx + src_indexing_axis_dim; + } + dst_floats[i] = src_floats[idx]; + } + } + } else { + // outer_dims_product specifies how many times we repeat inner dimensions, + // so we just iterate over it to cover all outer dimensions. + for (auto batch = 0; batch < outer_dims_product; ++batch) { + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (idx < 0) { + idx = idx + src_indexing_axis_dim; + } + + auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize; + auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize; + memcpy(dst, src, block_bytesize); + } + } + } + }); + return result_contig; +} + Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index) { dim = maybe_wrap_dim(dim, self.dim()); @@ -498,6 +579,11 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim return result; } + if (dim == 1 && result.is_contiguous()) { + // fast pass + return index_select_out_cpu_dim1_(result, self, index_contig); + } + auto selfSlice = self.select(dim, 0); auto resultSlice = result.select(dim, 0); auto selfSlice_data = selfSlice.data_ptr(); diff --git a/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp b/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp new file mode 100644 index 000000000000..b5ed77f6e400 --- /dev/null +++ b/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp @@ -0,0 +1,311 @@ +#include + +#include +#include +#include +#include +#include + +namespace at { namespace native { + +namespace { + +template +void cpu_adaptive_avg_pool( + Tensor& output_, + const Tensor& input_, + IntArrayRef output_size) { + auto input = input_.contiguous(); + auto output = output_.contiguous(); + + auto input_data = input.data_ptr(); + auto output_data = output.data_ptr(); + + int64_t ndim = input.ndimension(); + // treat batch size and channels as one dimension + int64_t channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1); + int64_t input_height = input.size(-2); + int64_t input_width = input.size(-1); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + // parallel on dim of N, C + at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) { + for (int64_t c = begin; c < end; c++) { + scalar_t* input_ptr = input_data + c * input_height * input_width; + scalar_t* output_ptr = output_data + c * output_height * output_width; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + // compute local average + scalar_t sum = 0; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + sum += input_ptr[ih * input_width + iw]; + } + } + output_ptr[oh * output_width + ow] = sum / kh / kw; + } + } + } + }); + + if (!output_.is_contiguous()) { + output_.copy_(output); + } +} + +template +void cpu_adaptive_avg_pool_channels_last( + Tensor& output_, + const Tensor& input_, + IntArrayRef output_size) { + auto memory_format = at::MemoryFormat::ChannelsLast; + auto input = input_.contiguous(memory_format); + auto output = output_.contiguous(memory_format); + + auto input_data = input.data_ptr(); + auto output_data = output.data_ptr(); + + int64_t nbatch = input.size(0); + int64_t channels = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + using Vec = vec256::Vec256; + // parallel on dim N, H, W + at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) { + int64_t n = 0; + int64_t oh = 0; + int64_t ow = 0; + data_index_init(begin, n, nbatch, oh, output_height, ow, output_width); + + for (int64_t i = begin; i < end; i++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t* out = output_data + i * channels; + int64_t size = channels; + + // Note: For oridinary usage scenario, each out lane should + // fit in L1 cache; otherwise consider block dim C. + // Pass I: zero the out lane + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec out_vec = Vec(scalar_t(0)); + out_vec.store(out + d1); + } + for (; d1 < size; d1++) { + out[d1] = scalar_t(0); + } + // Pass II: compute local sum + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + scalar_t* in = input_data + n * input_height * input_width * channels + + ih * input_width * channels + iw * channels; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2); + out_vec.store(out + d2); + } + for (; d2 < size; d2++) { + out[d2] += in[d2]; + } + } + } + // Pass III: compute local average + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kh * kw)); + out_vec.store(out + d3); + } + for (; d3 < size; d3++) { + out[d3] = out[d3] / kh / kw; + } + + // move on to next output index + data_index_step(n, nbatch, oh, output_height, ow, output_width); + } + }); + + if (!output_.is_contiguous(memory_format)) { + output_.copy_(output); + } +} + +template +void cpu_adaptive_avg_pool_backward( + Tensor& grad_input_, + const Tensor& grad_output_) { + auto grad_output = grad_output_.contiguous(); + auto grad_input = grad_input_.contiguous(); + + auto grad_output_data = grad_output.data_ptr(); + auto grad_input_data = grad_input.data_ptr(); + + int64_t ndim = grad_output.ndimension(); + // treat batch size and channels as one dimension + int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1); + int64_t input_height = grad_input.size(-2); + int64_t input_width = grad_input.size(-1); + int64_t output_height = grad_output.size(-2); + int64_t output_width = grad_output.size(-1); + + // parallel on dim of N, C + at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) { + for (int64_t c = begin; c < end; c++) { + scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width; + scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t grad_delta = grad_output_ptr[oh * output_width + ow] / kh / kw; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + grad_input_ptr[ih * input_width + iw] += grad_delta; + } + } + } + } + } + }); + + if (!grad_input_.is_contiguous()) { + grad_input_.copy_(grad_input); + } +} + +template +void cpu_adaptive_avg_pool_backward_channels_last( + Tensor& grad_input_, + const Tensor& grad_output_) { + auto memory_format = at::MemoryFormat::ChannelsLast; + auto grad_input = grad_input_.contiguous(memory_format); + auto grad_output = grad_output_.contiguous(memory_format); + + auto grad_input_data = grad_input.data_ptr(); + auto grad_output_data = grad_output.data_ptr(); + + int64_t nbatch = grad_input.size(0); + int64_t channels = grad_input.size(1); + int64_t input_height = grad_input.size(2); + int64_t input_width = grad_input.size(3); + int64_t output_height = grad_output.size(2); + int64_t output_width = grad_output.size(3); + + using Vec = vec256::Vec256; + // parallel on dim N + at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) { + for (int64_t n = begin; n < end; n++) { + scalar_t* grad_input_ptr = grad_input_data + n * input_height * input_width * channels; + scalar_t* grad_output_ptr = grad_output_data + n * output_height * output_width * channels; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t* gout = grad_output_ptr + oh * output_width * channels + ow * channels; + int64_t size = channels; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + scalar_t* gin = grad_input_ptr + ih * input_width * channels + iw * channels; + + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kh * kw)); + gin_vec.store(gin + d); + } + for (; d < size; d++) { + gin[d] += gout[d] / kw / kw; + } + } + } + } + } + } + }); + + if (!grad_input_.is_contiguous(memory_format)) { + grad_input_.copy_(grad_input); + } +} + +void adaptive_avg_pool2d_kernel_impl( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + switch (input.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d", [&] { + cpu_adaptive_avg_pool(output, input, output_size); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{ + cpu_adaptive_avg_pool_channels_last(output, input, output_size); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +void adapative_avg_pool2d_backward_kernel_impl( + Tensor& grad_input, + const Tensor& grad_output) { + switch (grad_output.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "adaptive_avg_pool2d_backward", [&] { + cpu_adaptive_avg_pool_backward(grad_input, grad_output); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "adaptive_avg_pool2d_backward_channels_last", [&]{ + cpu_adaptive_avg_pool_backward_channels_last(grad_input, grad_output); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +} // anonymous namespace + +REGISTER_DISPATCH(adaptive_avg_pool2d_kernel, &adaptive_avg_pool2d_kernel_impl); +REGISTER_DISPATCH(adaptive_avg_pool2d_backward_kernel, &adapative_avg_pool2d_backward_kernel_impl); + +}} // at::native diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index bfb136776333..f7c4f9c34613 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -277,7 +277,7 @@ static void sign_kernel(TensorIterator& iter){ [=](scalar_t a) -> scalar_t { return (0 < a) - (a < 0); }, [=](Vec256 self_vec){ - // Comparision operators returns bitmask. + // Comparison operators returns bitmask. auto left = Vec256::blendv(zero_vec, one_vec, zero_vec < self_vec); auto right = Vec256::blendv(zero_vec, one_vec, self_vec < zero_vec); diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index aa6d57cdd2df..61e7877761d8 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -4,36 +4,12 @@ #include #include #include +#include namespace at { namespace native { namespace { -template -inline T data_index_init(T offset) { - return offset; -} - -template -inline T data_index_init(T offset, T &x, const T &X, Args &&... args) { - offset = data_index_init(offset, std::forward(args)...); - x = offset % X; - return offset / X; -} - -inline bool data_index_step() { - return true; -} - -template -inline bool data_index_step(T &x, const T &X, Args &&... args) { - if (data_index_step(std::forward(args)...)) { - x = ((x + 1) == X) ? 0 : (x + 1); - return x == 0; - } - return false; -} - static inline int64_t nearest_idx( int64_t output_index, int64_t input_size, diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h new file mode 100644 index 000000000000..32d1de5adb51 --- /dev/null +++ b/aten/src/ATen/native/cpu/utils.h @@ -0,0 +1,30 @@ +#pragma once + +namespace at { namespace native { namespace { + +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T &x, const T &X, Args &&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T &x, const T &X, Args &&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +}}} // namespace at::native:: diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 2b81460c1a4b..d630d727019f 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -232,19 +232,17 @@ void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & va AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, value_.scalar_type(), "indexing_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "indexing_backward", [&] { - indexing_backward_kernel<<>>( - sorted_indices.data_ptr(), - orig_indices.data_ptr(), - value_.data_ptr(), - src_.data_ptr(), - num_indices, - sliceSize, - strideBefore, - nElemBefore); - }); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + indexing_backward_kernel<<>>( + sorted_indices.data_ptr(), + orig_indices.data_ptr(), + value_.data_ptr(), + src_.data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore); }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); if (permuted) self.copy_(src_.permute(inversePerm)); } @@ -508,77 +506,73 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const cuda::detail::canUse32BitIndexMath(source) && cuda::detail::canUse32BitIndexMath(index)) { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "index_add", [&] { - cuda::detail::TensorInfo selfInfo = - cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); - selfInfo.reduceDim(selfAddDim); - AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { - auto sourceInfo = - cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); - sourceInfo.reduceDim(sourceAddDim); - - auto indexInfo = - cuda::detail::getTensorInfo(index); - indexInfo.collapseDims(); - - // A reasonable choice for when to have each thread iterate over - // index to choose - if (numIndex <= 16) { - if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); - } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { - SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); - } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { - SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + cuda::detail::TensorInfo selfInfo = + cuda::detail::getTensorInfo(self_); + int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { + auto sourceInfo = + cuda::detail::getTensorInfo(source_); + int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + auto indexInfo = + cuda::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + // A reasonable choice for when to have each thread iterate over + // index to choose + if (numIndex <= 16) { + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + } else { + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + } + } else { + bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); + + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); } else { - SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); } - } else { - bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); - - if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); - } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); - } else { - LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); - } - } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); - } else { - LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); - } + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); } else { - LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); } + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); } - }); + } }); }); } else { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "index_add", [&] { - cuda::detail::TensorInfo selfInfo = - cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); - selfInfo.reduceDim(selfAddDim); - - cuda::detail::TensorInfo sourceInfo = - cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); - sourceInfo.reduceDim(sourceAddDim); - - AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { - cuda::detail::TensorInfo indexInfo = - cuda::detail::getTensorInfo(index); - indexInfo.collapseDims(); - - LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); - }); + cuda::detail::TensorInfo selfInfo = + cuda::detail::getTensorInfo(self_); + int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + + cuda::detail::TensorInfo sourceInfo = + cuda::detail::getTensorInfo(source_); + int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { + cuda::detail::TensorInfo indexInfo = + cuda::detail::getTensorInfo(index); + indexInfo.collapseDims(); + + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); }); }); } @@ -839,17 +833,10 @@ Tensor& index_select_out_cuda(Tensor& out, const Tensor& self, int64_t dim, TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); -#if defined(__HIP_PLATFORM_HCC__) AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, out.scalar_type(), "index_select_cuda", [&] { index_select_out_cuda_impl(out, self, dim, index); }); -#else // __HIP_PLATFORM_HCC__ - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::Bool, - out.scalar_type(), "index_select_cuda", - [&] { index_select_out_cuda_impl(out, self, dim, index); }); -#endif // __HIP_PLATFORM_HCC__ return out; } diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh index 8265c5999376..051583a12a53 100644 --- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh +++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh @@ -258,50 +258,24 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ + softmax_warp_forward \ + <<>>(dst, \ + src, batch_count, softmax_elements_stride, softmax_elements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024 default: break; } @@ -333,53 +307,27 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; + #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ + softmax_warp_backward \ + <<>> \ + (grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024 default: break; } } } - diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu index f70459928bf0..8437e80ebb48 100644 --- a/aten/src/ATen/native/cuda/Repeat.cu +++ b/aten/src/ATen/native/cuda/Repeat.cu @@ -23,6 +23,7 @@ static void compute_cuda(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *resu int64_t grid = std::min((size + warps_per_block - 1) / warps_per_block, 2048L); compute_cuda_kernel<<>>(repeat_ptr, cumsum_ptr, result_ptr, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } namespace at { namespace native { diff --git a/aten/src/ATen/native/cuda/SpectralOps.cu b/aten/src/ATen/native/cuda/SpectralOps.cu index de807c8c5300..db3e853a9321 100644 --- a/aten/src/ATen/native/cuda/SpectralOps.cu +++ b/aten/src/ATen/native/cuda/SpectralOps.cu @@ -125,6 +125,7 @@ void _fft_fill_with_conjugate_symmetry_cuda_( static_cast(in_data), input_offset_calculator, output_offset_calculator); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -589,112 +590,5 @@ Tensor _fft_c2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization return output; } -// cuFFT -// Currently not utilizing multi GPUs so this can be potentially sped up. -Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim, - bool complex_input, bool complex_output, bool inverse, - IntArrayRef checked_signal_sizes, int64_t normalization, bool onesided, - IntArrayRef output_sizes) { - - CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(self.device().index()); - - Tensor input = self; - const auto fft_type = GetCuFFTTransformType(complex_input, complex_output); - - if (complex_input) { - TORCH_CHECK(input.size(-1) == 2, "Expected a complex (size 2) last dimension"); - } - - - // Slice when twosided complex-to-real. This is not always needed because we - // calculate the inembed. But it will benefit us in certain cases where we - // clone the input tensor. - // - // See NOTE [ cuFFT Embedded Strides ]. - // See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h. - if (fft_type == CuFFTTransformType::C2R && !onesided) { - auto onesided_size = infer_ft_real_to_complex_onesided_size(checked_signal_sizes[signal_ndim - 1]); - input = input.narrow(signal_ndim, 0, onesided_size); - } - - // cuFFT requires input and output data pointers to complex type aligned. - // Our newly allocated output tensor is always 512 bytes aligned so it is fine - // (see kRoundSmall and kRoundLarge in THCCachingAllocator.cpp), but we do - // need to check input tensor to make sure that it is not unaligned, e.g., - // from a slicing. - bool must_clone = false; - auto complex_size_bytes = 2 * input.element_size(); - if (reinterpret_cast(input.data_ptr()) % complex_size_bytes != 0) { - must_clone = true; - } - - if (complex_input) { - auto strides = input.strides(); - // Real/imag dimension must be like complex type. - must_clone |= strides.back() != 1; - // Strides of other dimensions needs to be aligned when viewed as complex - // type, i.e., multiples of 2. - must_clone |= std::any_of(strides.begin(), strides.end() - 1, - [&](int64_t stride) { return stride % 2 != 0; }); - - // Complex to real FFTs may overwrite the input buffer (gh-34551) - must_clone |= !complex_output; - } - - if (must_clone) { - input = input.clone(MemoryFormat::Contiguous); - } - - // Now that we have done error check and data_ptr checks, we delegate all - // further cuFFT parameter computation and plan creation to the helper class - // CuFFTConfig in CuFFTPlanCache.h. - - // If plan caching is enabled, we check the cache. Note that this accesses - // plan_cache.max_size() and thus makes this function less functional. - // However, integrating additional arguments into the "public" level c++ APIs, - // e.g., irfft, is difficult as we have a long call sequence looking like - // irfft --> _fft --> _fft_with_size --dispatching-to-> _fft_cufft - - DimVector in_strides(signal_ndim + 1); - auto input_strides = input.strides(); - for (int64_t i = signal_ndim; i >= 0; --i) { - in_strides[i] = complex_input ? input_strides[i] / 2 : input_strides[i]; - } - - DimVector out_strides(signal_ndim + 1); - out_strides[signal_ndim] = 1; - if (fft_type == CuFFTTransformType::R2C && onesided) { - out_strides[signal_ndim - 1] = checked_signal_sizes[signal_ndim - 1] / 2 + 1; - } else { - out_strides[signal_ndim - 1] = checked_signal_sizes[signal_ndim - 1]; - } - for (int64_t i = signal_ndim - 2; i >= 0; --i) { - out_strides[i] = out_strides[i + 1] * checked_signal_sizes[i]; - } - - DimVector full_sizes(signal_ndim + 1); - full_sizes[0] = self.size(0); - std::copy(checked_signal_sizes.begin(), checked_signal_sizes.end(), full_sizes.begin() + 1); - CuFFTParams Params(in_strides, out_strides, full_sizes, fft_type, - c10::toValueType(input.scalar_type())); - - // This read is not locked for perf reason. Shouldn't matter too much because - // we check again after acquiring the lock. - if (plan_cache.max_size() > 0) { - std::lock_guard guard(plan_cache.mutex); - if (plan_cache.max_size() > 0) { // check again after acquiring the lock - const CuFFTConfig &config = plan_cache.lookup(Params); - return _run_cufft(config, input, signal_ndim, complex_input, - complex_output, inverse, checked_signal_sizes, - static_cast(normalization), - onesided, output_sizes, must_clone); - } - } - CuFFTConfig config(Params); - return _run_cufft(config, input, signal_ndim, complex_input, - complex_output, inverse, checked_signal_sizes, - static_cast(normalization), - onesided, output_sizes, must_clone); -} }} // at::native diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index a241f7df533c..effeef69f0cf 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -357,6 +357,7 @@ Tensor tril_indices_cuda( col, tril_size - rectangle_size, tril_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -434,6 +435,7 @@ Tensor triu_indices_cuda( col, rectangle_size, triu_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm index 8dd27aa1c3ed..9b1ff29feaa1 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm @@ -608,7 +608,9 @@ Tensor copy_to_host(const Tensor& input) { MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); auto&& sizes = [X sizes]; - MetalTensor mt{sizes}; + auto dummy = at::zeros(input.sizes()).contiguous(); + auto strides = dummy.strides(); + MetalTensor mt{sizes, strides.vec()}; mt.texture()->setCommandBuffer(commandBuffer); mt.texture()->allocateTextureStorage(sizes); MPSImage* Y = imageFromMetalTensor(mt); diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 9584fafcea4b..8fca9ad9ecdf 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -9,14 +9,6 @@ namespace at { namespace native { REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub, fft_fill_with_conjugate_symmetry_fn); -Tensor _fft_mkl(const Tensor& input, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - int64_t normalization, bool onesided, - IntArrayRef output_sizes) { - AT_ERROR("fft: ATen not compiled with MKL support"); -} - Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { AT_ERROR("fft: ATen not compiled with MKL support"); } @@ -280,97 +272,6 @@ static DftiDescriptor _plan_mkl_fft( return descriptor; } -// MKL DFTI -Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - int64_t normalization, bool onesided, - IntArrayRef output_sizes) { - Tensor input = self; - bool need_contiguous = false; - // real/imag dimension must aligned when viewed as of complex type - if (complex_input) { - need_contiguous |= input.stride(-1) != 1; - for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) { - need_contiguous |= input.stride(i) % 2 != 0; - } - } - - // check if we can use MKL because MKL_LONG is 32bit on some OS, e.g. Windows - // need to check input and output size and strides - // be careful about complex domain, where the stride needs to be divided by 2 - // only need to test upper bound MKL_LONG_MAX as these values are non-negative - if (sizeof(MKL_LONG) < sizeof(int64_t)) { - int64_t inumel = 1 /* istride if we contiguous-fy */, onumel = 1; - int64_t isize, osize, istride, ostride; - for (int64_t i = signal_ndim; i >= 0; i--) { - isize = input.size(i); - osize = output_sizes[i]; - istride = complex_input ? input.stride(i) >> 1 : input.stride(i); - ostride = onumel; - TORCH_CHECK(isize <= MKL_LONG_MAX && osize <= MKL_LONG_MAX && ostride <= MKL_LONG_MAX, - "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]"); - if (!need_contiguous && istride > MKL_LONG_MAX) { - // If we didn't plan to contiguous-fy but the `istride` exceeds bound, - // check if we can stride (equal to `inumel`) get back within bound if - // we contiguous-fy. If so, then we need to always check `inumel` - // instead for the remaining iterations. The iterations before this are - // fine as `inumel` is non-decreasing. - need_contiguous = true; - } - TORCH_CHECK(!need_contiguous || inumel <= MKL_LONG_MAX, - "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]"); - inumel *= isize; - onumel *= osize; - } - } - - if (need_contiguous) { - input = input.contiguous(); - } - - - Tensor output = at::empty(output_sizes, input.options()); - - DimVector full_sizes(signal_ndim + 1); - full_sizes[0] = self.size(0); - std::copy(checked_signal_sizes.cbegin(), checked_signal_sizes.cend(), full_sizes.begin() + 1); - - // If "complex" is true, convert strides from complex viewed as real to complex strides. - // Otherwise, returns a copy of strides if "complex" is false. - auto convert_strides = [signal_ndim](IntArrayRef strides, bool complex) { - DimVector res(signal_ndim + 1); - if (complex) { - for (int64_t i = 0; i < res.size(); ++i) { - res[i] = strides[i] / 2; - } - } else { - res.assign(strides.cbegin(), strides.cend()); - } - return res; - }; - const auto in_strides = convert_strides(input.strides(), complex_input); - const auto out_strides = convert_strides(output.strides(), complex_output); - - auto descriptor = _plan_mkl_fft( - in_strides, out_strides, full_sizes, complex_input, complex_output, - normalization, !inverse, input.scalar_type()); - - if (inverse) { - MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), output.data_ptr())); - } else { - MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), output.data_ptr())); - } - // now if needed, fill out the other half using Hermitian symmetry dim - if (!complex_input && complex_output && !onesided) { - DimVector signal_dims(signal_ndim); - std::iota(signal_dims.begin(), signal_dims.end(), 1); - auto out_as_complex = at::view_as_complex(output); - at::native::_fft_fill_with_conjugate_symmetry_(out_as_complex, signal_dims); - } - return output; -} - // Execute a general fft operation (can be c2c, onesided r2c or onesided c2r) static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes, IntArrayRef dim, int64_t normalization, bool forward) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 3d217fad9851..768ddf2fc17d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2054,17 +2054,6 @@ dispatch: CPU, CUDA: native_group_norm_backward -- func: _fft_with_size(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, bool normalized, bool onesided, int[] output_sizes) -> Tensor - use_c10_dispatcher: full - variants: function - -- func: _fft_with_size.norm_modes(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, int normalization, bool onesided, int[] output_sizes) -> Tensor - use_c10_dispatcher: full - variants: function - dispatch: - CPU: _fft_mkl - CUDA: _fft_cufft - # Real to complex forward FFT - func: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor use_c10_dispatcher: full diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index 5d25138500d7..660862181262 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -96,18 +96,16 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) { dim3 block(C10_WARP_SIZE, SZ); AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(), "coalesce_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "coalesce_sparse_cuda", [&] { - using cuda_accscalar_t = acc_type; - apply::coalesceValuesKernel<<>>( - uniqueOffsets.data_ptr(), - origIndices.data_ptr(), - values.data_ptr(), - newValues.data_ptr(), - nnz, - newNnz, - stride - ); - }); + using cuda_accscalar_t = acc_type; + apply::coalesceValuesKernel<<>>( + uniqueOffsets.data_ptr(), + origIndices.data_ptr(), + values.data_ptr(), + newValues.data_ptr(), + nnz, + newNnz, + stride + ); }); } diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 81058ec266f2..d0aafe680efb 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -340,13 +340,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> - <<>>( - TensorCAddOp(value.to()), - V_INFO(r), I_INFO(indices), V_INFO(values), - static_cast(nnz)); - }); + apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> + <<>>( + TensorCAddOp(value.to()), + V_INFO(r), I_INFO(indices), V_INFO(values), + static_cast(nnz)); }); } else { TORCH_CHECK(cuda::getApplyGrid(nnz * block.x, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); @@ -356,13 +354,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - apply::sparseElementwiseKernel, uint64_t, scalar_t> - <<>>( - TensorCAddOp(value.to()), - V_INFO(r), I_INFO(indices), V_INFO(values), - static_cast(nnz)); - }); + apply::sparseElementwiseKernel, uint64_t, scalar_t> + <<>>( + TensorCAddOp(value.to()), + V_INFO(r), I_INFO(indices), V_INFO(values), + static_cast(nnz)); }); } } else { @@ -373,11 +369,9 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT // NB: Purposely not inplace! AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - if (value.to() != static_cast(1)) { - values = values.mul(value); - } - }); + if (value.to() != static_cast(1)) { + values = values.mul(value); + } }); int64_t view_rows = 1; @@ -445,11 +439,9 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_sparse_cuda", [&] { - if (value.to() != static_cast(1)) { - s_values_ = s_values_.mul(value); - } - }); + if (value.to() != static_cast(1)) { + s_values_ = s_values_.mul(value); + } }); LongTensor r_indices_ = at::cat({t_indices_, s_indices_}, 1); Tensor r_values_ = at::cat({t_values_, s_values_}, 0); diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index bcd0fbc37e77..6b2e08576068 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -316,17 +316,6 @@ class TORCH_API RecordFunctionCallback { scopes_.fill(true); } - // This interface is for observers that do not pass an ObserverContext object - // between start and end callbacks. - explicit RecordFunctionCallback( - std::function start, - std::function end = - [](const RecordFunction&) {}): - start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }}, - end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} { - scopes_.fill(true); - } - RecordFunctionCallback& needsInputs(bool needs_inputs) { needs_inputs_ = needs_inputs; return *this; diff --git a/aten/src/TH/THAllocator.cpp b/aten/src/TH/THAllocator.cpp index 55b42d2f9d27..53b67a17032f 100644 --- a/aten/src/TH/THAllocator.cpp +++ b/aten/src/TH/THAllocator.cpp @@ -6,6 +6,7 @@ #endif #include +#include /* stuff for mapped files */ #ifdef _WIN32 @@ -74,24 +75,26 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, #ifdef _WIN32 if (flags_ & TH_ALLOCATOR_MAPPED_SHAREDMEM) { // Shadowing - const char *filename; - const char *eventname; + const wchar_t *filename; + const wchar_t *eventname; + const std::wstring wFilename = c10::u8u16(filename_); + const std::wstring wEventname = c10::u8u16(eventname_); LARGE_INTEGER hfilesz; if (filename_[0] == '/') { - filename = filename_.c_str() + 1; - eventname = eventname_.c_str() + 1; + filename = wFilename.c_str() + 1; + eventname = wEventname.c_str() + 1; } else { - filename = filename_.c_str(); - eventname = eventname_.c_str(); + filename = wFilename.c_str(); + eventname = wEventname.c_str(); } hfilesz.QuadPart = size; if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) { - event_ = CreateEvent(nullptr, FALSE, FALSE, eventname); + event_ = CreateEventW(nullptr, FALSE, FALSE, eventname); } else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) { - event_ = OpenEvent(EVENT_ALL_ACCESS, FALSE, eventname); + event_ = OpenEventW(EVENT_ALL_ACCESS, FALSE, eventname); } else { AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE"); } @@ -101,9 +104,9 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, } if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) { - handle_ = CreateFileMapping(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename); + handle_ = CreateFileMappingW(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename); } else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) { - handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, filename); + handle_ = OpenFileMappingW(FILE_MAP_ALL_ACCESS, FALSE, filename); } else { AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE"); } @@ -136,15 +139,21 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, AT_ERROR("TH_ALLOCATOR_MAPPED_FROMFD not supported on Windows"); } + // Shadowing + const wchar_t *filename; + const std::wstring wFilename = c10::u8u16(filename_); + + filename = wFilename.c_str(); + /* open file */ /* FILE_FLAG_RANDOM_ACCESS ? */ if (flags_) { - hfile = CreateFileA(filename_.c_str(), GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0); + hfile = CreateFileW(filename, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0); if (hfile == INVALID_HANDLE_VALUE) { AT_ERROR("could not open file <", filename_, "> in read-write mode; error code: <", GetLastError(), ">"); } } else { - hfile = CreateFileA(filename_.c_str(), GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + hfile = CreateFileW(filename, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); if (hfile == INVALID_HANDLE_VALUE) { AT_ERROR("could not open file <", filename_, "> in read-only mode; error code: <", GetLastError(), ">"); } @@ -181,11 +190,11 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, /* get map handle */ if (flags_) { - if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { + if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">"); } } else { - if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { + if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">"); } } diff --git a/aten/src/THC/THCApply.cuh b/aten/src/THC/THCApply.cuh index 368f1566e84c..7e52e1a1130c 100644 --- a/aten/src/THC/THCApply.cuh +++ b/aten/src/THC/THCApply.cuh @@ -6,6 +6,7 @@ #include #include #include +#include // // This file contains pointwise operation functions and kernels that @@ -242,14 +243,11 @@ bool THC_pointwiseApply1(THCState* state, // (or vice versa), the contiguous tensor can be collapsed to one // dimension, and the loop to translate the linear index to the array // index can be similarly collapsed. That is what this unrolling is for. -#define HANDLE_CASE(TYPE, A) \ - kernelPointwiseApply1 \ - <<>>( \ - OffsetInfo \ - (aInfo), \ - (TYPE) totalElements, op); +#define HANDLE_CASE(TYPE, A) \ + kernelPointwiseApply1 \ + <<>>( \ + OffsetInfo(aInfo), (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define HANDLE_A_CASE(TYPE, A) { \ switch (A) { \ @@ -298,6 +296,7 @@ bool THC_pointwiseApply1(THCState* state, uint64_t, 1> <<>>( aOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 @@ -310,6 +309,7 @@ bool THC_pointwiseApply1(THCState* state, uint64_t, -1> <<>>( aOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE @@ -392,16 +392,13 @@ bool THC_pointwiseApply2(THCState* state, // dimension, and the loop to translate the linear index to the array // index can be similarly collapsed. That is what this unrolling is for. #define HANDLE_CASE(TYPE, A, B) \ - kernelPointwiseApply2 \ + kernelPointwiseApply2 \ <<>>( \ - OffsetInfo \ - (aInfo), \ - OffsetInfo \ - (bInfo), \ - (TYPE) totalElements, op); + OffsetInfo(aInfo), \ + OffsetInfo(bInfo), \ + (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + #define HANDLE_B_CASE(TYPE, A, B) { \ switch (B) { \ @@ -474,6 +471,7 @@ bool THC_pointwiseApply2(THCState* state, uint64_t, 1, 1> <<>>( aOffset, bOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); @@ -488,6 +486,7 @@ bool THC_pointwiseApply2(THCState* state, uint64_t, -1, -1> <<>>( aOffset, bOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE @@ -598,7 +597,8 @@ bool THC_pointwiseApply3(THCState* state, (bInfo), \ OffsetInfo \ (cInfo), \ - (TYPE) totalElements, op); + (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define HANDLE_C_CASE(TYPE, A, B, C) { \ switch (C) { \ @@ -697,6 +697,7 @@ bool THC_pointwiseApply3(THCState* state, uint64_t, 1, 1, 1> <<>>( aOffset, bOffset, cOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); @@ -715,6 +716,7 @@ bool THC_pointwiseApply3(THCState* state, uint64_t, -1, -1, -1> <<>>( aOffset, bOffset, cOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE diff --git a/aten/src/THC/THCReduceAll.cuh b/aten/src/THC/THCReduceAll.cuh index 9546f85f61c9..af2e264e6528 100644 --- a/aten/src/THC/THCReduceAll.cuh +++ b/aten/src/THC/THCReduceAll.cuh @@ -10,6 +10,7 @@ // #include +#include #include #ifdef __HIP_PLATFORM_HCC__ @@ -209,6 +210,7 @@ void callReduceAll(THCState* state, <<>>( in, (IndexType) totalElements, init, modifyOp, reduceOp, (AccT*) scratchSpace); + C10_CUDA_KERNEL_LAUNCH_CHECK(); int numPass1Blocks = grid.x; getPass2ReduceBlockGrid(state, totalElements, grid, block); @@ -218,6 +220,7 @@ void callReduceAll(THCState* state, <<>>( numPass1Blocks, init, reduceOp, (AccT*) scratchSpace, devOut); + C10_CUDA_KERNEL_LAUNCH_CHECK(); THCudaFree(state, scratchSpace); } else { @@ -227,6 +230,7 @@ void callReduceAll(THCState* state, kernelReduceAll <<>>( in, (IndexType) totalElements, init, modifyOp, reduceOp, devOut); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } diff --git a/aten/src/THC/THCTensorSort.cu b/aten/src/THC/THCTensorSort.cu index 8969209a1bdc..189e73b909fb 100644 --- a/aten/src/THC/THCTensorSort.cu +++ b/aten/src/THC/THCTensorSort.cu @@ -1,5 +1,6 @@ #include #include +#include void THCudaLongTensor_fillSliceWithIndex(THCState* state, THCudaLongTensor* t, @@ -28,8 +29,10 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state, #define FILL_INDEX(T, DIM) \ fillSliceWithIndex \ - <<>>( \ - info, numSlices, sliceSize, info.strides[collapseDim]) + <<>>( \ + info, numSlices, sliceSize, info.strides[collapseDim]); \ + C10_CUDA_KERNEL_LAUNCH_CHECK() + if (THCTensor_canUse32BitIndexMath(state, t)) { TensorInfo info = @@ -59,6 +62,5 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state, } #undef FILL_INDEX - THCudaCheck(cudaGetLastError()); } } diff --git a/aten/src/THC/generic/THCTensorIndex.cu b/aten/src/THC/generic/THCTensorIndex.cu index 66ad275787f5..3f506d345714 100644 --- a/aten/src/THC/generic/THCTensorIndex.cu +++ b/aten/src/THC/generic/THCTensorIndex.cu @@ -4,6 +4,7 @@ #include #include +#include // Check tensor dimensions for index operations, and return the slice size. // src can be nullptr in case of indexFill: in that case it is ignored. @@ -127,11 +128,12 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ - indexCopySmallIndex \ - <<>>( \ - dstInfo, srcInfo, indicesInfo, \ - dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize); +#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ + indexCopySmallIndex \ + <<>>( \ + dstInfo, srcInfo, indicesInfo, \ + dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, TYPE, \ DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ @@ -141,7 +143,8 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT dstInfo, srcInfo, indicesInfo, \ dstCopyDim, srcCopyDim, srcTotalSize, \ (IDX_IS_MAJOR) ? sliceSize : numIndices, \ - dstCopyDimSize); + dstCopyDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); @@ -307,11 +310,12 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \ +#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \ indexFillSmallIndex \ - <<>>( \ - dstInfo, indicesInfo, \ - dstFillDim, sliceSize, dstFillDimSize, val); + <<>>( \ + dstInfo, indicesInfo, \ + dstFillDim, sliceSize, dstFillDimSize, val); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM, IDX_IS_MAJOR) \ indexFillLargeIndex \ @@ -319,7 +323,8 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT dstInfo, indicesInfo, \ dstFillDim, sliceSize * numIndices, \ (IDX_IS_MAJOR) ? sliceSize : numIndices, \ - dstFillDimSize, val); + dstFillDimSize, val); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu index 8c0dac0aa686..216a96443887 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.cu +++ b/aten/src/THC/generic/THCTensorMathMagma.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.cu" #else +#include + #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) #ifdef USE_MAGMA @@ -171,8 +173,10 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper dim3 threads(128); if (uplo == 'U') { THCTensor_(copyUpperSymmetric)<<>>(input_data, n, len); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { THCTensor_(copyLowerSymmetric)<<>>(input_data, n, len); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } THCTensor_(freeCopyTo)(state, input, ra_); diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu index 76f470ce7dfb..ce2f124215ca 100644 --- a/aten/src/THC/generic/THCTensorMathReduce.cu +++ b/aten/src/THC/generic/THCTensorMathReduce.cu @@ -41,9 +41,11 @@ void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar dim3 threads(32); THCTensor_kernel_renorm - <<>> - (THCTensor_(data)(state, data), scalar_cast(value), size, scalar_cast(maxnorm)); + <<>>(THCTensor_(data)(state, data), + scalar_cast(value), size, scalar_cast(maxnorm)); + // Do not replace with C10_CUDA_KERNEL_LAUNCH_CHECK() yet as it exhibits different behaviour from THError(). + // THError() calls the an error handler, or throws std::runtime_error if a custom handler hasn't been registered. cudaError_t errcode = cudaGetLastError(); if(errcode != cudaSuccess) THError(cudaGetErrorString(errcode)); diff --git a/aten/src/THC/generic/THCTensorMode.cu b/aten/src/THC/generic/THCTensorMode.cu index 9fe955f3cf8d..8c428c9a5d1b 100644 --- a/aten/src/THC/generic/THCTensorMode.cu +++ b/aten/src/THC/generic/THCTensorMode.cu @@ -2,6 +2,7 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMode.cu" #else +#include #include void THCTensor_(calculateMode)(THCState *state, @@ -235,14 +236,14 @@ void THCTensor_(mode)(THCState *state, // Macro that calls kernel --> note that we set the block dimensions here, and // the amount of shared memory - #define HANDLE_MODE(SIZE) \ - { \ - dim3 blockSize(SIZE / 2); \ -\ - int memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \ - computeMode \ - <<>>( \ - THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ + #define HANDLE_MODE(SIZE) \ + { \ + const dim3 blockSize(SIZE / 2); \ + const auto memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \ + computeMode \ + <<>>( \ + THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } // Tradeoff between compilation time and the number of specializations. Ideally we would have diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu index f3ca8bf93b1b..1ef540ba3302 100644 --- a/aten/src/THC/generic/THCTensorRandom.cu +++ b/aten/src/THC/generic/THCTensorRandom.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) @@ -39,6 +40,8 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, larger_short), one, inputsize ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + at::Tensor smaller_short_wrapped = THTensor_wrap(smaller_short); at::Tensor smaller_wrapped = THTensor_wrap(smaller); at::Tensor larger_short_wrapped = THTensor_wrap(larger_short); @@ -57,6 +60,8 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, larger_short), inputsize - h_large_c, h_large_c ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + scalar_t q_max = at::max(THTensor_wrap(_q)).item(); condDiv<<< inputBlockDim, BLOCK_SIZE, 0, c10::cuda::getCurrentCUDAStream()>>>( @@ -64,6 +69,7 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, _J), inputsize, q_max ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); THCudaLongTensor_free(state, smaller); THCudaLongTensor_free(state, larger); @@ -104,6 +110,8 @@ void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, T THCTensor_(data)(state, uniform), THCTensor_(data)(state, bernoulli) ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + THCTensor_(free)(state, uniform); THCTensor_(free)(state, bernoulli); } diff --git a/aten/src/THC/generic/THCTensorScatterGather.cu b/aten/src/THC/generic/THCTensorScatterGather.cu index 832539d370ce..a1ab8d63f163 100644 --- a/aten/src/THC/generic/THCTensorScatterGather.cu +++ b/aten/src/THC/generic/THCTensorScatterGather.cu @@ -2,10 +2,13 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorScatterGather.cu" #else +#include + #define RUN(TYPE, DIMS, REAL) \ - THCudaTensor_gatherKernel \ - <<>>( \ - tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements); + THCudaTensor_gatherKernel \ + <<>>( \ + tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index) { @@ -61,19 +64,15 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, switch (indexInfo.dims) { case 1: RUN(unsigned int, 1, scalar_t); - THCudaCheck(cudaGetLastError()); break; case 2: RUN(unsigned int, 2, scalar_t); - THCudaCheck(cudaGetLastError()); break; case 3: RUN(unsigned int, 3, scalar_t); - THCudaCheck(cudaGetLastError()); break; default: RUN(unsigned int, -1, scalar_t); - THCudaCheck(cudaGetLastError()); break; } } else { @@ -84,7 +83,6 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, TensorInfo indexInfo = getTensorInfo(state, index); RUN(uint64_t, -1, scalar_t); - THCudaCheck(cudaGetLastError()); } } diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu index b4da00a98b7f..e378fe03358e 100644 --- a/aten/src/THC/generic/THCTensorSort.cu +++ b/aten/src/THC/generic/THCTensorSort.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorSort.cu" #else +#include + // In alignment with default sort on a c++ map, this function // will permute key and value tensors identically, and // in such a way that the 'key' tensor is ordered numerically @@ -53,8 +55,9 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, dim3 block(blockSize); \ \ if (dir) { \ - bitonicSortKVInPlace, TYPE, SIZE> \ - <<>>( \ + bitonicSortKVInPlace, TYPE, SIZE> \ + <<>>( \ keyInfo, \ keySlices, \ (TYPE) keySliceSize, \ @@ -62,16 +65,19 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, valueInfo, \ (TYPE) valueInfo.strides[collapseValueDim], \ GTComp()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else { \ - bitonicSortKVInPlace, TYPE, SIZE> \ - <<>>( \ + bitonicSortKVInPlace, TYPE, SIZE> \ + <<>>( \ keyInfo, \ keySlices, \ (TYPE) keySliceSize, \ (TYPE) keyInfo.strides[collapseKeyDim], \ valueInfo, \ (TYPE) valueInfo.strides[collapseValueDim], \ - LTComp()); \ + LTComp()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } while (0) @@ -147,8 +153,6 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, #undef HANDLE_CASE #undef HANDLE_SORT_CASE #undef HANDLE_A_CASE - - THCudaCheck(cudaGetLastError()); } void THCTensor_(sortViaThrust)(THCState* state, diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index 357b3f2e22f3..8d7bf7701c04 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -3,6 +3,7 @@ #else #include +#include void THCTensor_(topk)(THCState* state, THCTensor *topK, @@ -37,8 +38,8 @@ void THCTensor_(topk)(THCState* state, // is provided to the kernel for the arguments. #define RUN_K(INDEX_T, DIM, DIR) \ - gatherTopK \ - <<>>( \ + gatherTopK \ + <<>>( \ inputInfo, \ static_cast(sliceSize), \ static_cast(k), \ @@ -50,7 +51,8 @@ void THCTensor_(topk)(THCState* state, static_cast(topKSlices), \ static_cast(topKInfo.strides[collapseTopKDim]), \ indicesInfo, \ - static_cast(indicesInfo.strides[collapseIndicesDim])) + static_cast(indicesInfo.strides[collapseIndicesDim])); \ + C10_CUDA_KERNEL_LAUNCH_CHECK() #define RUN_DIR(INDEX_T, DIM) \ if (dir) { \ @@ -71,10 +73,10 @@ void THCTensor_(topk)(THCState* state, } #define RUN_T(INDEX_T) \ - TensorInfo inputInfo = \ - getTensorInfo(state, input); \ - TensorInfo topKInfo = \ - getTensorInfo(state, topK); \ + TensorInfo inputInfo = \ + getTensorInfo(state, input); \ + TensorInfo topKInfo = \ + getTensorInfo(state, topK); \ TensorInfo indicesInfo = \ getTensorInfo(state, indices); \ \ diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py index 4fb7fffb5a5d..b0534bd9722d 100644 --- a/benchmarks/operator_benchmark/benchmark_caffe2.py +++ b/benchmarks/operator_benchmark/benchmark_caffe2.py @@ -50,10 +50,15 @@ def tensor(self, shapes, dtype='float32', device='cpu'): Return: C2 tensor of dtype """ + return self.feed_tensor(benchmark_utils.numpy_random(dtype, *shapes), device) + + def feed_tensor(self, tensor, device='cpu'): + """ Similar to tensor, but can supply any data compatible with FeedBlob + """ blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index) dev = self._device_option(device) with core.DeviceScope(dev): - workspace.FeedBlob(blob_name, benchmark_utils.numpy_random(dtype, *shapes)) + workspace.FeedBlob(blob_name, tensor) Caffe2BenchmarkBase.tensor_index += 1 return blob_name diff --git a/benchmarks/operator_benchmark/c2/batch_gather_test.py b/benchmarks/operator_benchmark/c2/batch_gather_test.py new file mode 100644 index 000000000000..ff3d84b99b2b --- /dev/null +++ b/benchmarks/operator_benchmark/c2/batch_gather_test.py @@ -0,0 +1,56 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core +import numpy + + +"""Microbenchmarks for element-wise BatchGather operator.""" + +# Configs for C2 BatherGather operator +batch_gather_configs_short = op_bench.config_list( + attr_names=["M", "N", "K"], + attrs=[ + [8, 8, 1], + [256, 512, 1], + [512, 512, 1], + [8, 8, 2], + [256, 512, 2], + [512, 512, 2], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"] +) + +batch_gather_configs_long = op_bench.cross_product_configs( + M=[128, 1024], + N=[128, 1024], + K=[1, 2], + device=['cpu', 'cuda'], + tags=["long"] +) + +class BatchGatherBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, M, N, K, device): + self.input_one = self.tensor([M, N, K], device=device) + max_val = N + numpy.random.seed((1 << 32) - 1) + index_dim = numpy.random.randint(0, N) + self.index = self.feed_tensor(numpy.random.randint(0, max_val, index_dim), device=device) + self.output = self.tensor([M, index_dim, K], device=device) + self.set_module_name("batch_gather") + + def forward(self): + op = core.CreateOperator("BatchGather", [self.input_one, self.index], self.output) + return op + + +op_bench_c2.generate_c2_test( + batch_gather_configs_long + batch_gather_configs_short, BatchGatherBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/index_select_test.py b/benchmarks/operator_benchmark/pt/index_select_test.py new file mode 100644 index 000000000000..8418edb2840b --- /dev/null +++ b/benchmarks/operator_benchmark/pt/index_select_test.py @@ -0,0 +1,57 @@ +import operator_benchmark as op_bench +import torch +import numpy + + +"""Microbenchmarks for index_select operator.""" + +# An example input from this configuration is M=4, N=4, dim=0. +index_select_configs_short = op_bench.config_list( + attr_names=["M", "N", "K", "dim"], + attrs=[ + [8, 8, 1, 1], + [256, 512, 1, 1], + [512, 512, 1, 1], + [8, 8, 2, 1], + [256, 512, 2, 1], + [512, 512, 2, 1], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"] +) + + +index_select_configs_long = op_bench.cross_product_configs( + M=[128, 1024], + N=[128, 1024], + K=[1, 2], + dim=[1], + device=['cpu', 'cuda'], + tags=["long"] +) + + +class IndexSelectBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, dim, device): + max_val = N + numpy.random.seed((1 << 32) - 1) + index_dim = numpy.random.randint(0, N) + self.inputs = { + "input_one": torch.rand(M, N, K, device=device), + "dim" : dim, + "index" : torch.tensor(numpy.random.randint(0, max_val, index_dim), device=device), + } + self.set_module_name("index_select") + + def forward(self, input_one, dim, index): + return torch.index_select(input_one, dim, index) + + +op_bench.generate_pt_test(index_select_configs_short + index_select_configs_long, + IndexSelectBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index 53a8bd16f43d..d47cedada40f 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -19,11 +19,11 @@ const float kLowSamplingProb = 0.0001; void addTestCallback( double sampling_prob = 1.0, - std::function fn = - [](const at::RecordFunction&) {}) { + std::function(const at::RecordFunction&)> fn = + [](const at::RecordFunction&) { return nullptr; }) { auto cb = at::RecordFunctionCallback( std::move(fn), - [](const at::RecordFunction&) {}) + [](const at::RecordFunction&, at::ObserverContext*) {}) .needsInputs(false); if (sampling_prob < 1.0) { cb.samplingProb(sampling_prob); @@ -111,6 +111,7 @@ int main(int argc, char** argv) { kLowSamplingProb, [&](const at::RecordFunction& fn) { ++cb_count; + return nullptr; } ); diff --git a/c10/core/Device.h b/c10/core/Device.h index 7827119bb0ac..04cd711c37b2 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -93,9 +93,13 @@ struct C10_API Device final { DeviceType type_; DeviceIndex index_ = -1; void validate() { - TORCH_CHECK(index_ == -1 || index_ >= 0, + // Removing these checks in release builds noticeably improves + // performance in micro-benchmarks. + // This is safe to do, because backends that use the DeviceIndex + // have a later check when we actually try to switch to that device. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(index_ == -1 || index_ >= 0, "Device index must be -1 or non-negative, got ", (int)index_); - TORCH_CHECK(!is_cpu() || index_ <= 0, + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_cpu() || index_ <= 0, "CPU device index must be -1 or zero, got ", (int)index_); } }; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index cd31c96c0d66..5deab2a09832 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1426,7 +1426,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually // exclusive see #24090 - refresh_contiguous(memory_format); + refresh_contiguous(); } bool is_strides_like_channels_last() const { @@ -1540,7 +1540,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * or strides. */ void refresh_contiguous() { - // NOTE: Make sure to keep the other overload in sync with this implementation! is_contiguous_ = compute_contiguous(); // Note: // Dim 0, 1, 2 will never be a channels last 2d/3d format @@ -1574,42 +1573,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } - /** - * Faster implementation of refresh_contiguous() that can be used if - * we know the current MemoryFormat. - */ - void refresh_contiguous(MemoryFormat memory_format) { - // NOTE: Make sure to keep the other overload in sync with this implementation! - is_contiguous_ = memory_format == MemoryFormat::Contiguous || compute_contiguous(); - switch (memory_format) { - case MemoryFormat::Contiguous: - is_channels_last_contiguous_ = false; - is_channels_last_contiguous_ = false; - is_channels_last_3d_contiguous_ = false; - is_channels_last_ = false; - is_channels_last_3d_ = false; - is_non_overlapping_and_dense_ = true; - break; - case MemoryFormat::ChannelsLast: - is_channels_last_contiguous_ = compute_channels_last_contiguous_2d(); - is_channels_last_ = true; - is_channels_last_3d_ = false; - is_channels_last_3d_contiguous_ = false; - is_non_overlapping_and_dense_ = is_contiguous_ || is_channels_last_contiguous_ || compute_non_overlapping_and_dense(); - break; - case MemoryFormat::ChannelsLast3d: - is_channels_last_contiguous_ = false; - is_channels_last_ = false; - is_channels_last_3d_ = true; - is_channels_last_3d_contiguous_ = compute_channels_last_contiguous_3d(); - is_non_overlapping_and_dense_ = is_contiguous_ || compute_non_overlapping_and_dense(); - break; - case MemoryFormat::Preserve: - // Is this case even possible? - refresh_contiguous(); - } - } - /** * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) * from one TensorImpl to another TensorImpl. diff --git a/c10/core/impl/LocalDispatchKeySet.cpp b/c10/core/impl/LocalDispatchKeySet.cpp index f984c40b39c0..358e6ef7e1f7 100644 --- a/c10/core/impl/LocalDispatchKeySet.cpp +++ b/c10/core/impl/LocalDispatchKeySet.cpp @@ -5,6 +5,10 @@ namespace c10 { namespace impl { +C10_DEFINE_bool(disable_variable_dispatch, false, "This flag forcibly disables the Variable code paths from executing, which currently breaks profiling in the process."); + +namespace { + /// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, /// thread_local is not supported. #ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY @@ -14,10 +18,26 @@ thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; #else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) -PODLocalDispatchKeySet raw_local_dispatch_key_set; +static PODLocalDispatchKeySet raw_local_dispatch_key_set; #endif +} // anonymous namespace + +LocalDispatchKeySet tls_local_dispatch_key_set() { + // Hack until variable performance is fixed + // + // ezyang: I'm pretty unhappy about this implementation, it looks wrong + // to me, as it seems to be performing a mutation on + // raw_local_dispatch_key_set. I can't conveniently test the correct + // version though... + if (FLAGS_disable_variable_dispatch) { + raw_local_dispatch_key_set.set_excluded( + raw_local_dispatch_key_set.excluded() | autograd_dispatch_keyset); + } + return raw_local_dispatch_key_set; +} + void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) { raw_local_dispatch_key_set = PODLocalDispatchKeySet { key_set.included_.raw_repr(), diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h index 7039272babf6..5262b1d4d6c0 100644 --- a/c10/core/impl/LocalDispatchKeySet.h +++ b/c10/core/impl/LocalDispatchKeySet.h @@ -23,6 +23,8 @@ namespace c10 { namespace impl { +C10_DECLARE_bool(disable_variable_dispatch); + // POD version of LocalDispatchKeySet. Declared here just so that // we can put it in the guards. struct C10_API PODLocalDispatchKeySet { @@ -52,19 +54,7 @@ struct C10_API LocalDispatchKeySet { DispatchKeySet excluded_; }; -/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, -/// thread_local is not supported. -#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY - extern thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; -#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) - extern PODLocalDispatchKeySet raw_local_dispatch_key_set; -#endif - -inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() { - // Don't let people fiddle with the thread_local directly just - // because they include this header. - return raw_local_dispatch_key_set; -} +C10_API LocalDispatchKeySet tls_local_dispatch_key_set(); // Internal, use ThreadLocalStateGuard C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set); diff --git a/c10/util/Unicode.h b/c10/util/Unicode.h new file mode 100644 index 000000000000..9cce93cc9b83 --- /dev/null +++ b/c10/util/Unicode.h @@ -0,0 +1,29 @@ +#pragma once + +#if defined(_WIN32) +#include +#include +#include +#endif + +namespace c10 { +#if defined(_WIN32) +inline std::wstring u8u16(const std::string& str) { + if (str.empty()) { + return std::wstring(); + } + int size_needed = MultiByteToWideChar( + CP_UTF8, 0, str.c_str(), static_cast(str.size()), NULL, 0); + TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); + std::wstring wstr(size_needed, 0); + MultiByteToWideChar( + CP_UTF8, + 0, + str.c_str(), + static_cast(str.size()), + &wstr[0], + size_needed); + return wstr; +} +#endif +} diff --git a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py index c5aea77d7199..041dcce97dbf 100644 --- a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py @@ -418,6 +418,147 @@ def test_small_sls(self, seed): ) assert 0 + @given(seed=st.integers(0, 65535)) + @settings(deadline=datetime.timedelta(seconds=10)) + def test_sls_layernorm(self, seed): + np.random.seed(seed) + workspace.ResetWorkspace() + + n = 2 + DIM = 3 + data = 4 * (np.random.random_sample((n, DIM)) + 1).astype(np.float32) + + lengths = np.array([n], dtype=np.int32) + indices = np.array(range(n), dtype=np.int64) + weights = np.random.uniform(low=0.01, high=0.5, size=[n]).astype(np.float32) + + pred_net = caffe2_pb2.NetDef() + pred_net.name = "pred" + pred_net.external_input.extend( + ["quantized_data", "weights", "indices", "lengths"] + ) + pred_net.external_output.append("Y_norm") + pred_net.external_output.append("Y_mean") + pred_net.external_output.append("Y_std") + + pred_net.op.add().CopyFrom( + core.CreateOperator( + "SparseLengthsWeightedSumFused8BitRowwise", + ["quantized_data", "weights", "indices", "lengths"], + ["Y"], + ) + ) + + pred_net.op.add().CopyFrom( + core.CreateOperator( + "LayerNorm", + ["Y"], + ["Y_norm", "Y_mean", "Y_std"], + epsilon=1e-4, + ) + ) + + ref_net = caffe2_pb2.NetDef() + ref_net.name = "ref" + ref_net.external_input.extend( + ["quantized_data", "weights", "indices", "lengths"] + ) + ref_net.external_output.append("Y_norm") + ref_net.external_output.append("Y_mean") + ref_net.external_output.append("Y_std") + + ref_net.op.add().CopyFrom( + core.CreateOperator( + "SparseLengthsWeightedSumFused8BitRowwiseFakeFP16NNPI", + ["quantized_data", "weights", "indices", "lengths"], + ["Y"], + ) + ) + + ref_net.op.add().CopyFrom( + core.CreateOperator( + "LayerNormFakeFP16NNPI", + ["Y"], + ["Y_norm", "Y_mean", "Y_std"], + epsilon=1e-4, + axis=1, + elementwise_affine=False + ) + ) + + workspace.FeedBlob("data", data) + workspace.RunOperatorOnce( + core.CreateOperator( + "FloatToFused8BitRowwiseQuantized", ["data"], ["quantized_data"] + ) + ) + + quantized_data = workspace.FetchBlob("quantized_data") + + onnxified_net = onnxifi_caffe2_net( + pred_net, + {}, + max_batch_size=1, + max_seq_size=n, + debug=True, + adjust_batch=True, + use_onnx=False, + ) + print("before", pred_net) + print("after", onnxified_net) + workspace.FeedBlob("indices", indices) + workspace.FeedBlob("lengths", lengths) + workspace.FeedBlob("weights", weights) + + workspace.CreateNet(onnxified_net) + workspace.CreateNet(ref_net) + + workspace.RunNet(onnxified_net.name) + Y_glow = workspace.FetchBlob("Y_norm") + Y_mean_glow = workspace.FetchBlob("Y_mean") + Y_std_glow = workspace.FetchBlob("Y_std") + + workspace.RunNet(ref_net.name) + Y = workspace.FetchBlob("Y") + print("pre normalization", Y) + Y_ref = workspace.FetchBlob("Y_norm") + Y_mean_ref = workspace.FetchBlob("Y_mean") + Y_std_ref = workspace.FetchBlob("Y_std") + + # print(Y_ref, Y_glow) + # print(Y_ref.shape, Y_glow.shape) + + diff = np.abs(Y_ref - Y_glow) + max_err = np.max(diff, axis=1) + num_offenders = (max_err > 0).sum() + if num_offenders > 0: + np.set_printoptions(precision=12) + print( + "ref", + Y_ref.astype(np.float16).astype(np.float32), + "glow", + Y_glow.astype(np.float16).astype(np.float32), + ) + print_test_debug_info( + "slws_fused_8bit_rowwise_inv_scale", + { + "seed": seed, + "indices": indices, + "data": data, + "quantized_data": quantized_data, + "lengths": lengths, + "weights": weights, + "Y_norm_glow": Y_glow, + "Y_norm_ref": Y_ref, + "Y_mean_glow": Y_mean_glow, + "Y_std_glow": Y_std_glow, + "Y_mean_ref": Y_mean_ref, + "Y_std_ref": Y_std_ref, + "diff": diff, + "rowwise_diff": np.max(diff, axis=1), + }, + ) + assert 0 if __name__ == '__main__': diff --git a/docs/source/conf.py b/docs/source/conf.py index fe1e2260be72..610f6efa0840 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -161,7 +161,7 @@ # TODO: verify this works as expected release = 'master' -# Customized html_title here. +# Customized html_title here. # Default is " ".join(project, release, "documentation") if not set if RELEASE: # remove hash (start with 'a') from version number if any @@ -192,6 +192,9 @@ # Disable docstring inheritance autodoc_inherit_docstrings = False +# Disable displaying type annotations, these can be very verbose +autodoc_typehints = 'none' + # -- katex javascript in header # @@ -253,9 +256,9 @@ def setup(app): add_css(css_file) # From PyTorch 1.5, we now use autogenerated files to document classes and -# functions. This breaks older references since +# functions. This breaks older references since # https://docs.pytorch.org/torch.html#torch.flip -# moved to +# moved to # https://docs.pytorch.org/torch/generated/torchflip.html # which breaks older links from blog posts, stack overflow answers and more. # To mitigate that, we add an id="torch.flip" in an appropriated place @@ -278,7 +281,7 @@ def visit_reference(self, node): # to autogenerated content anchor = ref_anchor[1] txt = node.parent.astext() - if txt == anchor or txt == anchor.split('.')[-1]: + if txt == anchor or txt == anchor.split('.')[-1]: self.body.append('

'.format(ref_anchor[1])) return old_call(self, node) Klass.visit_reference = visit_reference diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index e155537d7b99..deb7a161e1d3 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -189,6 +189,7 @@ ("aten::rfft", datetime.date(2021, 1, 31)), ("aten::quantile", datetime.date(2021, 1, 31)), ("aten::nanquantile", datetime.date(2021, 1, 31)), + ("aten::_fft_with_size", datetime.date(2021, 1, 31)), ] def allow_listed(schema, allow_list): diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index ca4ba0fdb3da..10f36cc8e394 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -739,8 +739,8 @@ void checkScopeCallbacks() { std::string(fn.name().str()) == "test_user_scope") { found_user_scope = true; } - }, - [](const at::RecordFunction&) {})); + return nullptr; + })); bool bad_scope = false; auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) { @@ -752,9 +752,8 @@ void checkScopeCallbacks() { } else { bad_scope = true; } - return true; - }, - [](const at::RecordFunction&) {}) + return nullptr; + }) .scopes({scope})); }; @@ -813,8 +812,8 @@ TEST(RecordFunctionTest, Basic) { } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { ts_names.insert(fn.name().str()); } - }, - [](const RecordFunction&) {}) + return nullptr; + }) .needsInputs(true)); TracedTestInputs eager_inputs, jit_inputs; @@ -851,9 +850,8 @@ TEST(RecordFunctionTest, Basic) { if (std::string(fn.name().str()) == "test") { ++sampled_cb_ctr; } - return true; - }, - [](const RecordFunction&) {}) + return nullptr; + }) .samplingProb(sampling_prob)); }; @@ -863,9 +861,8 @@ TEST(RecordFunctionTest, Basic) { if (std::string(fn.name().str()) == "test") { ++non_sampled_cb_ctr; } - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); auto handle = setup_sampled_callback(0.5); @@ -908,9 +905,8 @@ TEST(RecordFunctionTest, Basic) { [&fn_names, &mtx](const RecordFunction& fn) { std::lock_guard lock(mtx); fn_names.push_back(fn.name().str()); - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); { RecordFunctionGuard g1(false); { @@ -934,8 +930,10 @@ TEST(RecordFunctionTest, Basic) { std::vector ids; auto add_remove_test_add_cb = [&ids](size_t id) { return addGlobalCallback(RecordFunctionCallback( - [&ids, id](const RecordFunction& fn) { ids.push_back(id); }, - [](const RecordFunction&) {})); + [&ids, id](const RecordFunction& fn) { + ids.push_back(id); + return nullptr ; + })); }; auto h1 = add_remove_test_add_cb(1); @@ -972,8 +970,7 @@ TEST(RecordFunctionTest, Basic) { ids.clear(); addGlobalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(1); }, - [](const RecordFunction&) {})); + [&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; })); { RECORD_USER_SCOPE("test"); } @@ -983,8 +980,7 @@ TEST(RecordFunctionTest, Basic) { auto th = std::thread([&ids]() { addThreadLocalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(2); }, - [](const RecordFunction&) {})); + [&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; })); { RECORD_USER_SCOPE("test_thread"); } }); @@ -1070,8 +1066,7 @@ TEST(RecordFunctionTest, Basic) { bool ran = false; should_run = false; addGlobalCallback(RecordFunctionCallback( - [&ran](const RecordFunction& fn) { ran = true; }, - [](const RecordFunction&) {}) + [&ran](const RecordFunction& fn) { ran = true; return nullptr; }) .setShouldRun(shouldRunCallback)); { RECORD_USER_SCOPE("test"); } @@ -1093,8 +1088,8 @@ TEST(RecordFunctionTest, Basic) { auto handle = addThreadLocalCallback(RecordFunctionCallback( [&recorded_op](const RecordFunction& fn) { recorded_op = fn.name().str(); - }, - [](const RecordFunction&) {})); + return nullptr; + })); ThreadLocalState state; std::thread t_child([state]() { ThreadLocalStateGuard g_tls(state); @@ -1111,16 +1106,20 @@ TEST(RecordFunctionTest, Basic) { bool has_ids = false; addGlobalCallback( RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; }, - [](const RecordFunction&) {}) + [&has_ids](const RecordFunction& fn) { + has_ids = fn.handle() > 0; + return nullptr; + }) .needsIds(true)); { RECORD_USER_SCOPE("test"); } TORCH_CHECK(has_ids); clearCallbacks(); has_ids = false; addGlobalCallback(RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; }, - [](const RecordFunction&) {})); + [&has_ids](const RecordFunction& fn) { + has_ids = fn.handle() > 0; + return nullptr; + })); { RECORD_USER_SCOPE("test"); } TORCH_CHECK(!has_ids); clearCallbacks(); @@ -1138,6 +1137,7 @@ TEST(RecordFunctionTest, OperatorNameOverload) { } else { operator_names.insert("No Operator Name"); } + return nullptr; }) .scopes({at::RecordScope::FUNCTION})); auto t = torch::randn({1, 2, 3}, at::kCPU); @@ -1209,9 +1209,8 @@ TEST(ThreadLocalDebugInfoTest, Basic) { [&done](const RecordFunction&) { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); { c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); auto t = torch::randn({1, 2, 3}, at::kCPU); diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py index 3b25be6e49c1..4b3e962d835f 100644 --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -3268,6 +3268,20 @@ def forward(self, x): loss = criterion(output, target) loss.backward() + @requires_nccl() + @skip_if_not_multigpu + def test_pass_default_pg(self): + dist.init_process_group( + "nccl", + init_method=f"file://{self.file_name}", + world_size=self.world_size, + rank=self.rank, + ) + + default_pg = c10d.distributed_c10d._get_default_group() + dist.destroy_process_group(default_pg) + self.assertFalse(dist.is_initialized()) + @requires_nccl() @skip_if_not_multigpu def test_save_load_checkpoint(self): diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index 67a66be19d84..b057d12a285d 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -4401,6 +4401,22 @@ def test_cat_transform_non_uniform(self): t2.log_abs_det_jacobian(x2, y2)], dim=dim) self.assertEqual(actual_jac, expected_jac) + def test_cat_event_dim(self): + t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1) + t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1) + dim = 1 + bs = 16 + x1 = torch.randn(bs, 2) + x2 = torch.randn(bs, 2) + x = torch.cat([x1, x2], dim=1) + t = CatTransform([t1, t2], dim=dim, lengths=[2, 2]) + y1 = t1(x1) + y2 = t2(x2) + y = t(x) + actual_jac = t.log_abs_det_jacobian(x, y) + expected_jac = sum([t1.log_abs_det_jacobian(x1, y1), + t2.log_abs_det_jacobian(x2, y2)]) + def test_stack_transform(self): x1 = -1 * torch.arange(1, 101, dtype=torch.float) x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100 diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py new file mode 100644 index 000000000000..db06663285da --- /dev/null +++ b/test/fx/test_fx_const_fold.py @@ -0,0 +1,274 @@ +import unittest + +import torch +from torch.fx.experimental import const_fold + + +class TestConstFold(unittest.TestCase): + def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule): + self.assertTrue(mod_folded.const_subgraph_module is not None) + + # Check that the constants are attributes in the main subgraph. + num_folded_attrs = 0 + for node in mod_folded.graph.nodes: + if node.op == "get_attr" and (node.target in mod_folded.const_output_names): + num_folded_attrs += 1 + self.assertEqual(num_folded_attrs, len(mod_folded.const_output_names)) + + def test_const_fold_basic_one_attr_no_name_collision(self): + r""" + Perform constant folding conversion, from original mod to split constant folding + module with two split subgraphs, where there's a single attr to fold and + a single output attr result to replace. + + attr1 attr1 + | | | | + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul attr2 x / (input from previous subgraph + \ / \ / is attr) + add sub y + | \ / + output mul attr2 + \ / + add + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]])) + + def forward(self, x, y): + a = self.attr_1 + self.attr_1 + x = x - a + return x * y + self.attr_2 + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) + base_result = mod(in_x, in_y) + fold_result = mod_folded(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_one_attr_name_collision(self): + r""" + Perform constant folding conversion, from original mod to split constant folding + module with two split subgraphs, where there's a single attr to fold and + a single output attr result to replace. Name the attrs such that they will + collide by name with folded attrs. + + add_1 add_1 + | | | | + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul add_2 x / (input from previous subgraph + \ / \ / is attr) + add sub y + | \ / + output mul add_2 + \ / + add + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Note: Named as such to result in name collision. + self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]])) + self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]])) + + def forward(self, x, y): + a = self.add_1__CF + self.add_1__CF + x = x - a + return x * y + self.add_2__CF + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0]) + base_result = mod(in_x, in_y) + fold_result = mod_folded(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_noop(self): + r""" + Check that a graph with no constant folding is handled correctly. + + x attr1 + \ / + sub + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + + def forward(self, x): + return x - self.attr1 + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + + # Check that the folded graph module is None, since there was no folding to do. + self.assertTrue(mod_folded.const_subgraph_module is None) + + # Now run both folded and non-folded to check results equal. + in_x = torch.tensor([[-0.45]]) + base_result = mod(in_x) + fold_result = mod_folded(in_x) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_two_attr_three_input(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into a single output, and there are three placeholder inputs. + + attr1 attr2 attr1 attr2 + \ / \ / + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul z x / (input from previous subgraph + \ / \ / is attr) + div sub y + | \ / + output mul z + \ / + div + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]])) + + def forward(self, x, y, z): + a = self.attr1 + self.attr1 + sub = x - a + mul = sub * y + return mul / z + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y, in_z = ( + torch.tensor([[-0.45]]), + torch.tensor([0.9]), + torch.tensor([1.1]), + ) + base_result = mod(in_x, in_y, in_z) + fold_result = mod_folded(in_x, in_y, in_z) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_two_attr(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into a single output. + + attr1 attr2 attr1 attr2 + \ / \ / + x add add (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + sub x | (input from previous subgraph is attr) + | \ / + output sub + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.randn(2, 3)) + self.attr2 = torch.nn.Parameter(torch.randn(2, 3)) + + def forward(self, x): + y = self.attr1 + self.attr2 + return x + y + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x = torch.randn(2, 3) + fold_result = mod_folded(in_x) + base_result = mod(in_x) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_multi_const_folded_attrs(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into two new attrs. + + attr1 attr2 attr1 attr2 + / \ | / \ | + permute | sum permute | sum + \ / / \ / | + x add y / add | + \ / \ / | | + sub add output output (become attrs add_1 and mul_1) + \ / ==> --------+-------+------ (const/base subgraph split) + \ / x | y | (inputs from previous subgraph + add \ / \ / are attrs) + | sub add + linear \ / + | add + sigmoid | + | linear + output | + sigmoid + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.randn(4, 4)) + self.attr2 = torch.nn.Parameter(torch.randn(4, 4)) + self.lin = torch.nn.Linear(4, 4) + + def forward(self, x, y): + a = self.attr1 + self.attr1.permute(1, 0) + x = x - a + amax = torch.sum(self.attr2, dim=1) + y = y + amax + return torch.sigmoid(self.lin(x + y)) + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.randn(4, 4), torch.randn(4) + fold_result = mod_folded(in_x, in_y) + base_result = mod(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) diff --git a/test/test_cuda.py b/test/test_cuda.py index 6249c250ae2e..498d7e71620e 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2895,6 +2895,209 @@ def test_max_large_axis(self): def test_to_numpy(self): self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_capture_simple(self): + s1 = torch.cuda.Stream() + + with torch.cuda.stream(s1): + a = torch.zeros((1000,), device="cuda") + a += 1 + g = torch.cuda._Graph() + g.capture_begin() + a += 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s1) + + g.replay() + g.replay() + + self.assertTrue(a.sum().item() == 3000.) + + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_rng_functional(self): + # The caching allocator isn't yet graph-safe. + # In this test, graphed regions try to ensure allocator safety by + # stashing references to all temporaries. This is why we use _fused_dropout + # instead of a public dropout API: _fused_dropout returns the mask temporary + # as well as the output, so we can stash references to both. + # + # TODO: + # Switch to public dropout API when the allocator is made graph-safe. + ops_with_kwargs = ((torch._fused_dropout, {"p": 0.1}), + (torch.nn.functional.rrelu, {"training": True}),) + size = 10000 + + def run(op, kwargs): + a = torch.randn((size,), device="cuda", dtype=torch.float) + + torch.cuda.manual_seed(5) + + # Control + eager_out = a + for _ in range(6): + out = op(eager_out, **kwargs) + # _fused_dropout returns a tuple, rrelu returns a bare tensor. + eager_out = out[0] if isinstance(out, tuple) else out + + graph_in = a.clone() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + # warms up allocator so no mallocs occur in capture + refs = () + graph_out = graph_in + for _ in range(3): + out = op(graph_out, **kwargs) + refs += tuple(out) + graph_out = out[0] if isinstance(out, tuple) else out + del out, refs, graph_out + + torch.cuda.manual_seed(5) + + refs = () + g = torch.cuda._Graph() + g.capture_begin() + graph_out = graph_in + for _ in range(2): + out = op(graph_out, **kwargs) + refs += tuple(out) + graph_out = out[0] if isinstance(out, tuple) else out + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + # Runs a graphed->eager->graphed sequence of RNG ops. + # replay() plays 2 invocations of the op, so the sequence has 6 + # invocations total, matching Control. + # replay() reads from graph_in and writes to graph_out. + g.replay() + out = op(graph_out, **kwargs) + out = op(out[0], **kwargs)[0] if isinstance(out, tuple) else op(out, **kwargs) + graph_in.copy_(out) + g.replay() + + # If replay() updated RNG state correctly, graph_out + # should now hold data equal to eager_out. + try: + self.assertEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op, kwargs in ops_with_kwargs: + run(op, kwargs) + + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_rng_distributions(self): + # The caching allocator isn't yet graph-safe. + # In this test, all ops maintain static references to inputs and outputs + # that persist across replay(), so they should be safe to test with graphs, + # EXCEPT for multinomial which is a complicated compound op. + # + # TODO: + # Uncomment multinomial when the allocator is made graph-safe. + size = 10000 + input = torch.rand((size,), device="cuda", dtype=torch.float) + alloc = torch.empty((size,), device="cuda", dtype=torch.float) + + # Torch ops to test with sample args (tuple) and kwargs (dict) + torch_with_args = (("bernoulli", (input.clone(),), {}), + # ("multinomial", (input.clone(), size, True), {}), + # ("multinomial", (input.clone(), size // 2, False), {}), + ("normal", (input.clone() + 1, input.clone()), {}), + ("poisson", (input.clone(),), {}), + ("rand", (size,), {"device": "cuda", "dtype": torch.float}), + ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), + ("randn", (size,), {"device": "cuda", "dtype": torch.float}),) + + # Tensor methods to test with sample args (tuple) + tensor_with_args = (("bernoulli_", (input.clone(),)), + ("cauchy_", ()), + ("exponential_", ()), + ("geometric_", (0.3,)), + ("log_normal_", ()), + ("normal_", ()), + ("random_", ()), + ("uniform_", ()),) + + def run(module, op, args, kwargs): + torch.cuda.manual_seed(5) + + # Each path runs a dummy op to increment the state a bit before creating controls. + if (module == "torch"): + dummy = getattr(torch, op)(*args, **kwargs) + control1 = getattr(torch, op)(*args, **kwargs) + control2 = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + control1 = alloc.clone() + control2 = alloc.clone() + getattr(dummy, op)(*args) + getattr(control1, op)(*args) + getattr(control2, op)(*args) + + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + torch.cuda.manual_seed(5) + + g = torch.cuda._Graph() + if (module == "torch"): + g.capture_begin() + t1 = getattr(torch, op)(*args, **kwargs) + t2 = getattr(torch, op)(*args, **kwargs) + g.capture_end() + else: + t1 = alloc.clone() + t2 = alloc.clone() + g.capture_begin() + getattr(t1, op)(*args) + getattr(t2, op)(*args) + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + try: + self.assertNotEqual(control1, t1) + self.assertNotEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # Runs a dummy op prelude, as for controls, to make sure replay() + # picks up the dummy op's state increment. + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + getattr(dummy, op)(*args) + + # Runs RNG ops that fill t1 and t2. + g.replay() + + try: + self.assertEqual(control1, t1) + self.assertEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op_with_args in torch_with_args: + run("torch", *op_with_args) + + for meth_with_args in tensor_with_args: + # Adds an empty dict for kwargs, which none of the Tensor methods use + run("Tensor", *(meth_with_args + ({},))) + class TestCudaComm(TestCase): def _test_broadcast(self, input): @@ -3279,5 +3482,6 @@ class TestNamedTupleInput_1(NamedTuple): self.assertEqual(expected_a, x.a) self.assertEqual(expected_b, x.b) + if __name__ == '__main__': run_tests() diff --git a/test/test_indexing.py b/test/test_indexing.py index f3430a158d89..b92fd94e8cbd 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -764,7 +764,7 @@ def test_int_indices(self, device): @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool) @dtypesIfCPU(torch.float, torch.long, torch.bool, torch.bfloat16) - @dtypesIfCUDA(torch.half, torch.long, torch.bool) + @dtypesIfCUDA(torch.half, torch.long, torch.bool, torch.bfloat16) def test_index_put_src_datatype(self, device, dtype): src = torch.ones(3, 2, 4, device=device, dtype=dtype) vals = torch.ones(3, 2, 4, device=device, dtype=dtype) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 956a115e6d56..8b04418fa640 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -71,6 +71,19 @@ def setUp(self): torch._C._jit_set_texpr_fuser_enabled(True) self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + self.int_dtypes = [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.bool, + ] + self.fp_dtypes = [ + torch.float16, + torch.float32, + torch.float64, + ] + self.dtypes = self.int_dtypes + self.fp_dtypes def tearDown(self): torch._C._jit_set_profiling_executor(self.old_profiling_executor) @@ -461,21 +474,13 @@ def test_bitwise_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.bool, - ] binary_ops = [ operator.__and__, operator.__or__, operator.__xor__ ] devices = self.devices - for dtype, op, device in product(dtypes, binary_ops, devices): + for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -500,20 +505,12 @@ def test_minmax_int_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.bool, - ] binary_ops = [ torch.min, torch.max ] devices = self.devices - for dtype, op, device in product(dtypes, binary_ops, devices): + for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1215,17 +1212,6 @@ def test_unary_ops(self): def apply(fn): return lambda x: fn(x) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] unary_ops = [ torch.lgamma, torch.sigmoid, @@ -1262,7 +1248,7 @@ def apply(fn): lambda x: torch.clamp(x, -10, 10), ] sizes = [(1,), (2,), (4, 4)] - for dtype, op, device, size in product(dtypes, unary_ops, self.devices, sizes): + for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): try: x = self.data_for(dtype, device, size=size) fn = apply(op) @@ -1286,18 +1272,7 @@ def test_binary_ops(self): def apply(fn): return lambda x, y: fn(x, y) - dtypes = [ - # FIXME: Fails in IR Eval: torch.int8 and_ cpu - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] + # FIXME: Fails in IR Eval: torch.int8 and_ cpu binary_ops = [ operator.__and__, operator.__or__, @@ -1329,7 +1304,7 @@ def apply(fn): torch.remainder, ] devices = self.devices - for dtype, op, device in product(dtypes, binary_ops, devices): + for dtype, op, device in product(self.dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1355,18 +1330,7 @@ def test_binary_tensor_scalar_ops(self): def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - # FIXME: Fails in IR Eval: torch.int64 and_ cpu - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool - ] + # FIXME: Fails in IR Eval: torch.int64 and_ cpu binary_ops = [ operator.__and__, operator.__or__, @@ -1376,11 +1340,9 @@ def apply_with_scalar(fn, scalar): torch.mul, torch.eq, torch.ne, - - # FIXME: fails with dtype=uint8, scalar=-1 - # torch.ge, - # torch.lt, - # torch.gt, + torch.ge, + torch.lt, + torch.gt, # FIXME: segfaults on CPU backend # operator.__rshift__, @@ -1390,7 +1352,7 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] - for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) @@ -1413,17 +1375,6 @@ def test_binary_div_ops(self): def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool - ] binary_ops = [ torch.div, torch.remainder, @@ -1433,7 +1384,7 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, -2.0, -1] # skip 0 - for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) @@ -1457,7 +1408,6 @@ def apply_with_scalar(fn, scalar): dtypes = [ torch.int8, - torch.uint8, torch.int16, torch.int32, torch.int64, @@ -1498,23 +1448,12 @@ def test_ternary_ops(self): def apply(fn): return lambda x, y, z: fn(x, y, z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] ternary_ops = [ torch.lerp, torch.addcmul, ] devices = self.devices - for dtype, op, device in product(dtypes, ternary_ops, devices): + for dtype, op, device in product(self.dtypes, ternary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1540,22 +1479,11 @@ def test_list_ops(self): def apply(fn): return lambda x, y, z: fn([x * x, y * y, z * z]) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] devices = self.devices list_ops = [ torch.cat, ] - for dtype, op, device in product(dtypes, list_ops, devices): + for dtype, op, device in product(self.dtypes, list_ops, devices): try: x = self.data_for(dtype, device, size=[5, 4, 1, 7]) y = self.data_for(dtype, device, size=[5, 4, 1, 7]) @@ -1580,24 +1508,13 @@ def test_where_ops(self): def apply(fn): return lambda cond, x, y: fn(cond, x, y) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bool, - ] ops = [ torch.where, lambda cond, x, y: torch.where(cond, x, 3.1415), lambda cond, x, y: torch.where(cond, 42, y), ] devices = self.devices - for dtype, op, device in product(dtypes, ops, devices): + for dtype, op, device in product(self.dtypes, ops, devices): try: cond = self.data_for(torch.bool, device) x = self.data_for(dtype, device) @@ -1624,6 +1541,7 @@ def fn(x): return x * x + x unsupported_dtypes = [ + torch.uint8, torch.bfloat16, torch.complex32, torch.complex64, diff --git a/test/test_kernel_launch_checks.py b/test/test_kernel_launch_checks.py index 079a7182a1fc..698a5cda2a42 100644 --- a/test/test_kernel_launch_checks.py +++ b/test/test_kernel_launch_checks.py @@ -26,9 +26,16 @@ def test_check_code(self): """)) # Does it work for macros? - self.assertEqual(0, check_code_for_cuda_kernel_launches(""" -#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \\ + self.assertEqual(0, check_code_for_cuda_kernel_launches(r""" +#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \ C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ + indexAddSmallIndex \ + <<>>( \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); """)) def test_check_cuda_launches(self): diff --git a/test/test_nn.py b/test/test_nn.py index 67412d54eed9..652b4d85cbed 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3587,49 +3587,57 @@ def test_adaptive_pooling_size_none(self): output = module(input) self.assertEqual(output.size(), (4,) + (2,) * (numel - 1) + (4,)) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_adaptive_pooling_avg_nhwc(self): - input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32, device="cuda") - input = input.contiguous(memory_format=torch.channels_last).requires_grad_() - grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32, device="cuda") - pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + device_list = ['cpu'] + if TEST_CUDA: + device_list.append('cuda') - ref_input = input.detach().clone().contiguous().requires_grad_(True) - ref_grad = grad.detach().clone().contiguous() - ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + for device in device_list: + input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device) + input = input.contiguous(memory_format=torch.channels_last).requires_grad_() + grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device) + pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - out = pool(input) - out.backward(grad) - ref_out = ref_pool(ref_input) - ref_out.backward(ref_grad) + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out) - self.assertEqual(input.grad, ref_input.grad) + out = pool(input) + out.backward(grad) + ref_out = ref_pool(ref_input) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(input.grad, ref_input.grad) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_adaptive_pooling_avg_nhwc_non_contiguous(self): - input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32, device="cuda") - input = input.contiguous(memory_format=torch.channels_last) - input = input[:, ::2, :, :].requires_grad_() - grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32, device="cuda") - grad = grad[:, ::2, :, :] - pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + device_list = ['cpu'] + if TEST_CUDA: + device_list.append('cuda') - ref_input = input.detach().clone().contiguous().requires_grad_(True) - ref_grad = grad.detach().clone().contiguous() - ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + for device in device_list: + input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device) + input = input.contiguous(memory_format=torch.channels_last) + input = input[:, ::2, :, :].requires_grad_() + grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device) + grad = grad[:, ::2, :, :] + pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - out = pool(input) - out.backward(grad) - ref_out = ref_pool(ref_input) - ref_out.backward(ref_grad) + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out) - self.assertEqual(input.grad, ref_input.grad) + out = pool(input) + out.backward(grad) + ref_out = ref_pool(ref_input) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(input.grad, ref_input.grad) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @largeTensorTest('12GB', device='cuda') @@ -6479,16 +6487,6 @@ def test_RNN_change_dropout(self): self.assertNotEqual(output2.data, prev_output) prev_output = output1.data - def _verify_pixel_shuffle(self, input, output, upscale_factor): - for c in range(output.size(1)): - for h in range(output.size(2)): - for w in range(output.size(3)): - height_idx = h // upscale_factor - weight_idx = w // upscale_factor - channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \ - (c * upscale_factor ** 2) - self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx]) - def test_inplace_thnn(self): modules = [nn.ReLU, nn.ELU, nn.SELU, nn.CELU, nn.RReLU] for mod in modules: @@ -6519,18 +6517,74 @@ def test_noncontig_conv_grad_cuda(self, dtype=torch.float): self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0) def test_pixel_shuffle(self): - batch_size = random.randint(1, 3) - upscale_factor = random.randint(2, 5) - channels = random.randint(1, 4) * upscale_factor ** 2 - height = random.randint(5, 10) - width = random.randint(5, 10) - - input = torch.rand(batch_size, channels, height, width, requires_grad=True) - ps = nn.PixelShuffle(upscale_factor) - output = ps(input) - self._verify_pixel_shuffle(input.data, output.data, upscale_factor) - output.backward(output.data) - self.assertEqual(input.data, input.grad.data) + def _test_pixel_shuffle_helper(num_input_dims, valid_channels_dim=True): + # Function to imperatively ensure pixels are shuffled to the correct locations. + # Used to validate the batch operations in pixel_shuffle. + def _verify_pixel_shuffle(input, output, upscale_factor): + for c in range(output.size(-3)): + for h in range(output.size(-2)): + for w in range(output.size(-1)): + height_idx = h // upscale_factor + weight_idx = w // upscale_factor + channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \ + (c * upscale_factor ** 2) + self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx]) + + upscale_factor = random.randint(2, 5) + # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2. + channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1) + height = random.randint(5, 10) + width = random.randint(5, 10) + + if num_input_dims == 1: + input = torch.rand(channels, requires_grad=True) + elif num_input_dims == 2: + input = torch.rand(height, width, requires_grad=True) + else: + batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)] + input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True) + ps = nn.PixelShuffle(upscale_factor) + + if num_input_dims >= 3 and valid_channels_dim: + output = ps(input) + _verify_pixel_shuffle(input, output, upscale_factor) + output.backward(output.data) + self.assertEqual(input.data, input.grad.data) + else: + self.assertRaises(RuntimeError, lambda: ps(input)) + + def test_pixel_shuffle_1D(): + _test_pixel_shuffle_helper(num_input_dims=1) + + def test_pixel_shuffle_2D(): + _test_pixel_shuffle_helper(num_input_dims=2) + + def test_pixel_shuffle_3D_with_valid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=3) + + def test_pixel_shuffle_4D_with_valid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=4) + + def test_pixel_shuffle_5D_with_valid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=5) + + def test_pixel_shuffle_3D_with_invalid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=3, valid_channels_dim=False) + + def test_pixel_shuffle_4D_with_invalid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=4, valid_channels_dim=False) + + def test_pixel_shuffle_5D_with_invalid_channels_dim(): + _test_pixel_shuffle_helper(num_input_dims=5, valid_channels_dim=False) + + test_pixel_shuffle_1D() + test_pixel_shuffle_2D() + test_pixel_shuffle_3D_with_valid_channels_dim() + test_pixel_shuffle_4D_with_valid_channels_dim() + test_pixel_shuffle_5D_with_valid_channels_dim() + test_pixel_shuffle_3D_with_invalid_channels_dim() + test_pixel_shuffle_4D_with_invalid_channels_dim() + test_pixel_shuffle_5D_with_invalid_channels_dim() def test_elu_inplace_view(self): v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True) diff --git a/test/test_sparse.py b/test/test_sparse.py index 72a67caa2038..5af630c0acb4 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -10,7 +10,7 @@ import random import unittest from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ - do_test_empty_full, load_tests, TEST_NUMPY, TEST_WITH_ROCM, IS_WINDOWS + do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version from numbers import Number from torch.autograd.gradcheck import gradcheck @@ -1301,7 +1301,6 @@ def test_spadd_hybrid(self): self._test_spadd_shape(10, [50, 30, 20], [2, 0]) @cuda_only - @unittest.skipIf(not TEST_WITH_ROCM, "runs only on ROCm") def test_sparse_add_out_bfloat16(self): # fp32 x, _, _ = self._gen_sparse(3, 5, 10) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index eada68c9ff92..6cdccf468326 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -420,11 +420,11 @@ def easy(x, y): traced = torch.jit.trace( easy, (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), - torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8)), + torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)), ) a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) - b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8) + b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) x = warmup_and_run_forward(traced, a, b) self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) diff --git a/test/test_torch.py b/test/test_torch.py index d2566a90f382..b4d9ad6f23c0 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -19,15 +19,15 @@ from torch import multiprocessing as mp from torch.testing._internal.common_utils import ( TestCase, TEST_WITH_ROCM, run_tests, - IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, + IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, - skipIfRocm, skipIfNoSciPy, + skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, - skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, + skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA, @@ -341,9 +341,6 @@ def test_device(self): self.assertEqual(90, cuda90.index) self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu:1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1)) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1)) self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 ')) self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2')) @@ -356,7 +353,6 @@ def test_device(self): self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3')) - self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1)) self.assertRaises(RuntimeError, lambda: torch.device(-1)) self.assertRaises(RuntimeError, lambda: torch.device('other')) @@ -1856,15 +1852,14 @@ def test_storage_casts(self): self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage') self.assertIs(complexdouble_storage.dtype, torch.complex128) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") def test_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.FloatStorage.from_file(f.name, True, size) + def assert_with_filename(filename): + size = 10000 + s1 = torch.FloatStorage.from_file(filename, True, size) t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) # check mapping - s2 = torch.FloatStorage.from_file(f.name, True, size) + s2 = torch.FloatStorage.from_file(filename, True, size) t2 = torch.FloatTensor(s2) self.assertEqual(t1, t2, atol=0, rtol=0) @@ -1878,15 +1873,24 @@ def test_from_file(self): t2.fill_(rnum) self.assertEqual(t1, t2, atol=0, rtol=0) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") + # release the tensors + del s1, t1, s2, t2 + + with TemporaryFileName() as fname: + assert_with_filename(fname) + + if IS_FILESYSTEM_UTF8_ENCODING: + with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname: + assert_with_filename(fname) + def test_torch_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.from_file(f.name, True, size, dtype=torch.float) + def assert_with_filename(filename): + size = 10000 + s1 = torch.from_file(filename, True, size, dtype=torch.float) t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) # check mapping - s2 = torch.from_file(f.name, True, size, dtype=torch.float) + s2 = torch.from_file(filename, True, size, dtype=torch.float) t2 = torch.FloatTensor(s2) self.assertEqual(t1, t2, atol=0, rtol=0) @@ -1900,6 +1904,16 @@ def test_torch_from_file(self): t2.fill_(rnum) self.assertEqual(t1, t2, atol=0, rtol=0) + # release the tensors + del s1, t1, s2, t2 + + with TemporaryFileName() as fname: + assert_with_filename(fname) + + if IS_FILESYSTEM_UTF8_ENCODING: + with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname: + assert_with_filename(fname) + def test_print(self): default_type = torch.Tensor().type() for t in torch._tensor_classes: @@ -6242,7 +6256,6 @@ class TestDevicePrecision(TestCase): exact_dtype = True @onlyCUDA - @skipCUDAIfNotRocm def test_index_add_bfloat16(self, device): inp_tensor = torch.randn(5, 3, device='cpu').bfloat16() t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.bfloat16, device='cpu') diff --git a/third_party/tensorpipe b/third_party/tensorpipe index 82a114882e21..5381c57ba923 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit 82a114882e21b176916e2f12a7b566af3d63df71 +Subproject commit 5381c57ba923481ffaf7c40f9acc7f164ded887f diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD index 66c7b1c7a1ab..45b99e64ec9a 100644 --- a/third_party/tensorpipe.BUILD +++ b/third_party/tensorpipe.BUILD @@ -93,7 +93,13 @@ TENSORPIPE_HEADERS = glob([ TENSORPIPE_BASE_SRCS = glob([ "tensorpipe/*.cc", "tensorpipe/channel/*.cc", - "tensorpipe/common/*.cc", + "tensorpipe/common/address.cc", + "tensorpipe/common/epoll_loop.cc", + "tensorpipe/common/error.cc", + "tensorpipe/common/fd.cc", + "tensorpipe/common/ibv.cc", + "tensorpipe/common/socket.cc", + "tensorpipe/common/system.cc", "tensorpipe/core/*.cc", "tensorpipe/transport/*.cc", "tensorpipe/util/*/*.cc", @@ -107,7 +113,10 @@ TENSORPIPE_SRCS = TENSORPIPE_BASE_SRCS + glob([ ]) TENSORPIPE_SRCS_CUDA = TENSORPIPE_SRCS + glob([ + "tensorpipe/common/cuda_loop.cc", + "tensorpipe/channel/cuda_basic/*.cc", "tensorpipe/channel/cuda_ipc/*.cc", + "tensorpipe/channel/cuda_xth/*.cc", ]) cc_library( diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index b88596c2b609..8791dfa7b095 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1829,9 +1829,6 @@ grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, false, false, false, false, grad_input_mask) # fft -- name: _fft_with_size.norm_modes(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, int normalization, bool onesided, int[] output_sizes) -> Tensor - self: fft_backward(self, grad, signal_ndim, complex_input, complex_output, inverse, checked_signal_sizes, normalization, onesided, output_sizes) - - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor self: fft_r2c_backward(grad, dim, normalization, onesided, self.size(dim.back())) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 6c9ad0d5d6e1..eca10839ae88 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -472,6 +472,7 @@ libtorch_python_cuda_core_sources = [ "torch/csrc/cuda/python_comm.cpp", "torch/csrc/cuda/Storage.cpp", "torch/csrc/cuda/Stream.cpp", + "torch/csrc/cuda/Graph.cpp", "torch/csrc/cuda/serialization.cpp", "torch/csrc/cuda/shared/cudart.cpp", "torch/csrc/cuda/shared/nvtx.cpp", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index a7f1f1b91c93..2a31552068a1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -735,6 +735,10 @@ class _CudaEventBase: def synchronize(self) -> None: ... def ipc_handle(self) -> bytes: ... +# Defined in torch/csrc/cuda/Graph.cpp +class _CudaGraphBase: + ... + # Defined in torch/csrc/DataLoader.cpp def _set_worker_signal_handlers(*arg: Any) -> None: ... # THPModule_setWorkerSignalHandlers def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ... # THPModule_setWorkerPIDs diff --git a/torch/__init__.py b/torch/__init__.py index 403c192b47e9..30c328c1da6f 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -359,11 +359,13 @@ def set_deterministic(d): * :class:`torch.nn.FractionalMaxPool2d` when called on a CUDA tensor that requires grad * :class:`torch.nn.FractionalMaxPool3d` when called on a CUDA tensor that requires grad * :func:`torch.nn.functional.interpolate` when called on a CUDA tensor that requires grad - and one of the following modes is used: - - `linear` - - `bilinear` - - `bicubic` - - `trilinear` + and one of the following modes is used: + + - `linear` + - `bilinear` + - `bicubic` + - `trilinear` + * :class:`torch.nn.ReflectionPad1d` when called on a CUDA tensor that requires grad * :class:`torch.nn.ReflectionPad2d` when called on a CUDA tensor that requires grad * :class:`torch.nn.ReplicationPad1d` when called on a CUDA tensor that requires grad diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index b23ab81ada93..3795b6e4f914 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -646,6 +646,7 @@ bool THCPComplexFloatStorage_init(PyObject *module); void THCPStream_init(PyObject *module); void THCPEvent_init(PyObject *module); +void THCPGraph_init(PyObject *module); #ifdef USE_CUDA PyMethodDef* THCPModule_methods(); @@ -786,6 +787,7 @@ PyObject* initModule() { THCPStream_init(module); THCPEvent_init(module); + THCPGraph_init(module); #endif auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) { diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 6da1a7e5e934..ed08e541661b 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2315,94 +2315,6 @@ std::tuple cholesky_solve_backward( return std::tuple{grad_self, grad_input2}; } -// Generally speaking, fft's backward is ifft. -Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - int64_t normalization, bool onesided, - IntArrayRef output_sizes) { - Tensor gI; - if (!complex_input && complex_output) { - // Forward is R2C - // Do inverse C2C and project onto real plane because grad can be - // asymmetrical so C2R can't be used. - if (onesided) { - // Forward is R2C (onesided) - // Think of onesided R2C rfft as - // 1. view as complex numbers (fill complex dim with zeros) - // 2. C2C fft - // 3. discard half of results - // So backward is - // 1. fill the other half with zeros (with `zero_grad_shape` below) - // (C2C ifft only take twosided inputs so we need to fill here) - // 2. inverse C2C ifft - // 3. discard the complex dim - int64_t zero_length = checked_signal_sizes[signal_ndim - 1] - grad.size(signal_ndim); - auto complex_full_grad = grad; - if (zero_length > 0) { - std::vector zero_grad_shape(signal_ndim + 2); - zero_grad_shape[0] = self.size(0); - for (int64_t i = 1; i < signal_ndim; i++) { - zero_grad_shape[i] = checked_signal_sizes[i - 1]; - } - zero_grad_shape[signal_ndim] = zero_length; - zero_grad_shape[signal_ndim + 1] = 2; - complex_full_grad = at::cat({ grad, at::zeros(zero_grad_shape, grad.options()) }, signal_ndim); - } - gI = _fft_with_size(complex_full_grad, signal_ndim, - /* complex_input */ true, /* complex_output */ true, - !inverse, checked_signal_sizes, normalization, - /* onesided */ false, complex_full_grad.sizes()).select(-1, 0); - } else { - gI = _fft_with_size(grad, signal_ndim, /* complex_input */ true, - /* complex_output */ true, !inverse, - checked_signal_sizes, normalization, - /* onesided */ false, grad.sizes()).select(-1, 0); - } - } else if (complex_input && !complex_output && onesided) { - // Forward is C2R (onesided) - // Think of onesided C2R irfft as - // 1. fill the other half by conjugate symmetry - // 2. inverse C2C ifft - // 3. discard the complex dimension - // So backward is - // 1. R2C rfft (essentially add dummy complex dimension, and dft) - // 2. accumulate gradient by conjugate symmetry - // since rfft results follow conjugate symmetry, we only need to - // double some entries from onesided rfft results, i.e., the ones with - // their reflected indices also landing out of the onesided range. So - // consider the index of last dim: - // i. idx = 0. - // Reflected to (N - 0) % N = 0. Not doubled. - // ii 0 < idx < floor(N/2) (last). - // N > N - idx > ceil(N/2) - // Reflected to () - // iii. idx = floor(N/2) = N/2 (last) when N even. - // Reflected to (N - N/2) % N = N/2. Not doubled. - // iv. idx = floor(N/2) = (N-1)/2 (last) when N odd. - // Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled. - // Therefore, needs to double - // idx = 1, 2, ..., N/2 - 1 when N even - // idx = 1, 2, ..., (N-1)/2 when N odd - // that is - // idx = 1, 2, ..., N - (floor(N/2) + 1) - // = 1, 2, ..., N - onesided_length - gI = _fft_with_size(grad, signal_ndim, /* complex_input */ false, - /* complex_output */ true, /* inverse */ false, - checked_signal_sizes, normalization, /* onesided */ true, - self.sizes()); - int64_t double_length = checked_signal_sizes[signal_ndim - 1] - self.size(signal_ndim); - if (double_length > 0) { // also covers case when signal size is zero - gI.narrow(signal_ndim, 1, double_length).mul_(2); - } - } else { - gI = _fft_with_size(grad, signal_ndim, complex_output, complex_input, - !inverse, checked_signal_sizes, normalization, onesided, - self.sizes()); - } - return gI; -} - Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) { // Forward is C2R (onesided) // Think of onesided C2R irfft as diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 78336ded0d88..488b7be9bd8a 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -173,8 +173,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { }); m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { auto cb = at::RecordFunctionCallback( - [](const at::RecordFunction&) {}, - [](const at::RecordFunction&) {}) + [](const at::RecordFunction&) { return nullptr; }, + [](const at::RecordFunction&, at::ObserverContext*) {}) .needsInputs(true) .samplingProb(sampling_prob); if (is_global) { diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 88cf22321865..eb52aec8920d 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -417,7 +417,7 @@ void pushProfilingCallbacksLegacy() { [](const at::RecordFunction& fn) { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { - return; + return nullptr; } bool record_cuda = state_ptr->config().state == ProfilerState::CUDA; @@ -432,8 +432,10 @@ void pushProfilingCallbacksLegacy() { } else { state_ptr->pushRange(fn, record_cuda, msg); } + + return nullptr; }, - [](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn, at::ObserverContext*) { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { return; diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp new file mode 100644 index 000000000000..b258f00bcf90 --- /dev/null +++ b/torch/csrc/cuda/Graph.cpp @@ -0,0 +1,46 @@ +#include + +#include + +#include +#include + +#include + +// Cargo culted partially from csrc/distributed/c10d/init.cpp +// and partially from csrc/cuda/Stream.cpp. +// THCPStream_init is also declared at global scope. + +// Because THCPGraph_init is forward declared in the only consumer (csrc/Module.cpp) +// I don't think we need a Graph.h. + +template +using shared_ptr_class_ = py::class_>; + +void THCPGraph_init(PyObject *module) { + // Pybind11 patch notes say "py::module_" is more up-to-date syntax, + // but CI linter and some builds prefer "module". + auto torch_C_m = py::handle(module).cast(); + + shared_ptr_class_<::at::cuda::CUDAGraph>(module, "_CudaGraphBase") + .def(py::init<>()) + .def("capture_begin", + &::at::cuda::CUDAGraph::capture_begin, + py::call_guard(), + R"(``capture_begin`` begins Cuda graph capture on the current stream.)") + .def("capture_end", + &::at::cuda::CUDAGraph::capture_end, + py::call_guard(), + R"(``capture_end`` ends Cuda graph capture on the current stream. + After ``capture_end``, ``replay`` may be called on this instance.)") + .def("replay", + &::at::cuda::CUDAGraph::replay, + py::call_guard(), + R"(``replay`` replays the Cuda graph captured by this instance.)") + // reset is called in __del__ on the Python side + // (see class Graph in torch/cuda/streams.py for reasons and caveats) + .def("reset", + &::at::cuda::CUDAGraph::reset, + py::call_guard(), + R"(``reset`` deletes the graph currently held by this instance.)"); +} diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index aaaaf6185dde..0d2c1c20b555 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -987,7 +987,7 @@ std::tuple InsertQuantDeQuantHelper:: v->debugName(), " exists."); QParamVector qparams; - c10::QScheme qscheme; + c10::QScheme qscheme = c10::kPerTensorAffine; auto observer_module = module.attr(observer_name.value()).toModule(); auto scalar_type = observer_module.attr("dtype"); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 6f587b910866..8a71e52db556 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -739,15 +739,29 @@ class TensorExprFuser { }; // clang-format on - // Value is only supported if operands are floats. - if (node->isMemberOf(float_only_operator_set)) { - for (const Value* v : node->inputs()) { - if (auto const& tt = v->type()->cast()) { - auto const& st = tt->scalarType(); - if (!st || !isFloatingType(*st)) { - return false; - } - } else if (!v->type()->cast()) { + for (const Value* v : node->inputs()) { + if (auto const& tt = v->type()->cast()) { + auto const& st = tt->scalarType(); + + // All tensors must be typed. + if (!st) { + return false; + } + + // Byte tensors introduce too many corner cases in type promotion. + // Better not to try to handle them. + if (*st == c10::ScalarType::Byte) { + return false; + } + + // These operators only support floats, because integer divisors need to + // raise ZeroDivisionError. + if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) { + return false; + } + } else if (node->isMemberOf(float_only_operator_set)) { + // Check scalar operands of float-only ops. + if (!v->type()->cast()) { return false; } } diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 0850b535fe30..a4068ac6d7f3 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -16,7 +16,7 @@ import threading from typing import List, Optional, Tuple, Union from ._utils import _get_device_index, _dummy_type -from .streams import Stream, Event +from .streams import Stream, Event, _Graph from .. import device as _device import torch._C diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 14345baf6abd..9c9c30a7ff29 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -8,6 +8,7 @@ # Define dummy base classes torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase') torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase') + torch._C.__dict__['_CudaGraphBase'] = _dummy_type('_CudaGraphBase') class Stream(torch._C._CudaStreamBase): r"""Wrapper around a CUDA stream. @@ -20,7 +21,7 @@ class Stream(torch._C._CudaStreamBase): device(torch.device or int, optional): a device on which to allocate the stream. If :attr:`device` is ``None`` (default) or a negative integer, this will use the current device. - priority(int, optional): priority of the stream. Can be either + priority(int, optional): priority of the stream. Can be either -1 (high priority) or 0 (low priority). By default, streams have priority 0. @@ -201,3 +202,5 @@ def __repr__(self): return ''.format(self._as_parameter_.value) else: return '' + +_Graph = torch._C._CudaGraphBase diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index bbcef98d4214..751621189706 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -126,6 +126,7 @@ def powerSGD_hook( # Incorporate the error from the previous state into the gradients. bucket_index = bucket.get_index() + input_tensor_cp = None if state.use_error_feedback: # The buckets can be rebuilt during training. # In this case, the error tensor shape will not be aligned with the input tensor, @@ -189,9 +190,7 @@ def compute_q(fut): torch.matmul(matrix.t(), p, out=q) return [ - dist.all_reduce(q, group=group_to_use, async_op=True) - .get_future() - .wait()[0] + dist.all_reduce(q, group=group_to_use, async_op=True).get_future().wait()[0] ] def decompress(fut): @@ -201,6 +200,7 @@ def decompress(fut): if state.use_error_feedback: # Memorize the local errors. state.error_dict[bucket_index] = input_tensor_cp - input_tensor + assert not torch.any(torch.isnan(state.error_dict[bucket_index])) ret = input_tensor.resize_(total_length) return [ret] diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 83260ec8dbdf..387da70403b0 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -147,7 +147,8 @@ def __getattribute__(self, key): class group(object): - WORLD = object() + # Points to the default PG once initialized. + WORLD: Optional[ProcessGroup] = None class GroupMember(object): @@ -166,7 +167,6 @@ class GroupMember(object): _pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {} # Default process group state -_default_pg: Optional[ProcessGroup] = None _default_pg_init_method = None # Process group count for default naming @@ -177,7 +177,7 @@ def _rank_not_in_group(group: ProcessGroup): """ Helper that checks if the current process's rank is not in a given group. """ - if group == GroupMember.WORLD: + if group is None: return False return group == GroupMember.NON_GROUP_MEMBER @@ -214,23 +214,12 @@ def _get_global_rank(group, group_rank): raise RuntimeError("The group rank is not part of the group") -def _check_default_pg() -> ProcessGroup: - """ - Helper that checks if the default ProcessGroup has been initialized, with - assertion. - """ - if _default_pg is not None: - return _default_pg - else: - raise RuntimeError("Default process group is not initialized") - - def _get_group_size(group): """ Helper that gets a given group's world size. """ - if group is GroupMember.WORLD: - default_pg = _check_default_pg() + if group is GroupMember.WORLD or group is None: + default_pg = _get_default_group() return default_pg.size() if group not in _pg_group_ranks: raise RuntimeError("The given group does not exist") @@ -306,7 +295,7 @@ def is_initialized(): """ Checking if the default process group has been initialized """ - return _default_pg is not None + return GroupMember.WORLD is not None def _get_default_group(): @@ -316,7 +305,7 @@ def _get_default_group(): if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " "please make sure to call init_process_group.") - return _default_pg + return GroupMember.WORLD def _get_default_store(): @@ -326,12 +315,15 @@ def _get_default_store(): if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " "please make sure to call init_process_group.") - default_pg = _check_default_pg() + default_pg = _get_default_group() _, default_store = _pg_map[default_pg] return default_store +def _update_default_pg(pg): + GroupMember.WORLD = group.WORLD = pg + -def get_backend(group=group.WORLD): +def get_backend(group=None): """ Returns the backend of the given process group. @@ -344,8 +336,8 @@ def get_backend(group=group.WORLD): The backend of the given process group as a lower case string. """ - if group == GroupMember.WORLD: - pg = _check_default_pg() + if group is None: + pg = _get_default_group() else: pg = group if _rank_not_in_group(pg): @@ -421,14 +413,13 @@ def init_process_group(backend, """ global _pg_group_ranks global _backend - global _default_pg global _default_pg_init_method if not isinstance(timeout, timedelta): raise RuntimeError("Expected timeout argument to be of type" "datetime.timedelta") - if _default_pg is not None: + if GroupMember.WORLD is not None: raise RuntimeError("trying to initialize the default process group " "twice!") @@ -450,14 +441,14 @@ def init_process_group(backend, "are ignored since they are assigned by the " "MPI runtime.".format(world_size, rank)) - _default_pg = _new_process_group_helper( + _update_default_pg(_new_process_group_helper( -1, -1, [], Backend.MPI, None, group_name=group_name, - timeout=timeout) + timeout=timeout)) else: # backward compatible API if store is None: @@ -467,17 +458,17 @@ def init_process_group(backend, store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) - _default_pg = _new_process_group_helper( + _update_default_pg(_new_process_group_helper( world_size, rank, [], backend, store, group_name=group_name, - timeout=timeout) + timeout=timeout)) - _pg_group_ranks[_default_pg] = {i: i for i in range(_default_pg.size())} - _backend = _pg_map[_default_pg][0] + _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore + _backend = _pg_map[GroupMember.WORLD][0] # type: ignore _default_pg_init_method = init_method # barrier at the end to ensure that once we return from this method, all @@ -537,7 +528,7 @@ def _new_process_group_helper(world_size, # If this is a subgroup (which means group_ranks is specified), # we check if the current process is a member of the new group. if not is_default_group: - global_rank = _check_default_pg().rank() + global_rank = _get_default_group().rank() if global_rank not in group_ranks: return GroupMember.NON_GROUP_MEMBER @@ -576,7 +567,7 @@ def _new_process_group_helper(world_size, return pg -def destroy_process_group(group=group.WORLD): +def destroy_process_group(group=None): """ Destroy a given process group, and deinitialize the distributed package @@ -589,15 +580,14 @@ def destroy_process_group(group=group.WORLD): global _pg_map global _pg_names global _pg_group_ranks - global _default_pg global _default_pg_init_method global _group_count if group == GroupMember.NON_GROUP_MEMBER: return - if group == GroupMember.WORLD: - pg = _default_pg + if group is None: + pg = GroupMember.WORLD else: pg = group @@ -605,8 +595,8 @@ def destroy_process_group(group=group.WORLD): if _pg_map.get(pg, None) is None: raise RuntimeError("Invalid process group specified") - if group == GroupMember.WORLD: - _default_pg = None + if group is None or group == GroupMember.WORLD: + _update_default_pg(None) _default_pg_init_method = None _pg_map.clear() _pg_names.clear() @@ -627,7 +617,7 @@ def destroy_process_group(group=group.WORLD): del _pg_group_ranks[pg] -def get_rank(group=group.WORLD): +def get_rank(group=None): """ Returns the rank of current process group @@ -636,7 +626,8 @@ def get_rank(group=group.WORLD): ``world_size``. Arguments: - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: The rank of the process group @@ -646,19 +637,20 @@ def get_rank(group=group.WORLD): if _rank_not_in_group(group): return -1 - default_pg = _check_default_pg() - if group == GroupMember.WORLD: + default_pg = _get_default_group() + if group is None or group is GroupMember.WORLD: return default_pg.rank() return _get_group_rank(group, default_pg.rank()) -def get_world_size(group=group.WORLD): +def get_world_size(group=None): """ Returns the number of processes in the current process group Arguments: - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: The world size of the process group @@ -673,7 +665,7 @@ def get_world_size(group=group.WORLD): def isend(tensor, dst, - group=group.WORLD, + group=None, tag=0): """ Sends a tensor asynchronously. @@ -681,7 +673,8 @@ def isend(tensor, Arguments: tensor (Tensor): Tensor to send. dst (int): Destination rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with remote recv Returns: @@ -693,8 +686,8 @@ def isend(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() return default_pg.send([tensor], dst, tag) else: group_dst_rank = _get_group_rank(group, dst) @@ -703,7 +696,7 @@ def isend(tensor, def irecv(tensor, src, - group=group.WORLD, + group=None, tag=0): """ Receives a tensor asynchronously. @@ -711,7 +704,8 @@ def irecv(tensor, Arguments: tensor (Tensor): Tensor to fill with received data. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match recv with remote send Returns: @@ -723,8 +717,8 @@ def irecv(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() return default_pg.recv([tensor], src, tag) else: group_src_rank = _get_group_rank(group, src) @@ -733,7 +727,7 @@ def irecv(tensor, def send(tensor, dst, - group=group.WORLD, + group=None, tag=0): """ Sends a tensor synchronously. @@ -741,7 +735,8 @@ def send(tensor, Arguments: tensor (Tensor): Tensor to send. dst (int): Destination rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with remote recv """ @@ -749,8 +744,8 @@ def send(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() default_pg.send([tensor], dst, tag).wait() else: group_dst_rank = _get_group_rank(group, dst) @@ -759,7 +754,7 @@ def send(tensor, def recv(tensor, src=None, - group=group.WORLD, + group=None, tag=0): """ Receives a tensor synchronously. @@ -768,7 +763,8 @@ def recv(tensor, tensor (Tensor): Tensor to fill with received data. src (int, optional): Source rank. Will receive from any process if unspecified. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match recv with remote send Returns: @@ -780,8 +776,8 @@ def recv(tensor, if _rank_not_in_group(group): return -1 - if group == GroupMember.WORLD: - pg = _check_default_pg() + if group is None: + pg = _get_default_group() else: pg = group @@ -789,12 +785,12 @@ def recv(tensor, work = pg.recv_anysource([tensor], tag) work.wait() src_rank = work._source_rank() - if group == GroupMember.WORLD: + if group is None or group is GroupMember.WORLD: return src_rank else: return _get_global_rank(pg, src_rank) else: - if group == GroupMember.WORLD: + if group is None or group is GroupMember.WORLD: pg.recv([tensor], src, tag).wait() else: group_src_rank = _get_group_rank(pg, src) @@ -816,17 +812,18 @@ class P2POp(object): ``torch.distributed.irecv``. tensor (Tensor): Tensor to send or receive. peer (int): Destination or source rank. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with recv. """ - def __init__(self, op, tensor, peer, group=group.WORLD, tag=0): + def __init__(self, op, tensor, peer, group=None, tag=0): self.op = op self.tensor = tensor self.peer = peer self.group = group self.tag = tag - def __new__(cls, op, tensor, peer, group=group.WORLD, tag=0): + def __new__(cls, op, tensor, peer, group=None, tag=0): _check_op(op) _check_single_tensor(tensor, "tensor") return object.__new__(cls) @@ -896,7 +893,7 @@ def batch_isend_irecv(p2p_op_list): def broadcast_multigpu(tensor_list, src, - group=group.WORLD, + group=None, async_op=False, src_tensor=0): """ @@ -920,7 +917,8 @@ def broadcast_multigpu(tensor_list, for all the distributed processes calling this function. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op src_tensor (int, optional): Source tensor rank within ``tensor_list`` @@ -936,8 +934,8 @@ def broadcast_multigpu(tensor_list, opts.rootRank = src opts.rootTensor = src_tensor - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.broadcast(tensor_list, opts) else: group_src_rank = _get_group_rank(group, src) @@ -951,7 +949,7 @@ def broadcast_multigpu(tensor_list, def broadcast(tensor, src, - group=group.WORLD, + group=None, async_op=False): """ Broadcasts the tensor to the whole group. @@ -963,7 +961,8 @@ def broadcast(tensor, tensor (Tensor): Data to be sent if ``src`` is the rank of current process, and tensor to be used to save received data otherwise. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -979,8 +978,8 @@ def broadcast(tensor, opts.rootRank = src opts.rootTensor = 0 - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.broadcast([tensor], opts) else: group_src_rank = _get_group_rank(group, src) @@ -994,7 +993,7 @@ def broadcast(tensor, def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): r""" Reduces the tensor data across all machines in such a way that all get @@ -1020,7 +1019,8 @@ def all_reduce_multigpu(tensor_list, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1035,8 +1035,8 @@ def all_reduce_multigpu(tensor_list, opts = AllreduceOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allreduce(tensor_list, opts) else: work = group.allreduce(tensor_list, opts) @@ -1049,7 +1049,7 @@ def all_reduce_multigpu(tensor_list, def all_reduce(tensor, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces the tensor data across all machines in such a way that all get @@ -1065,7 +1065,8 @@ def all_reduce(tensor, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1107,8 +1108,8 @@ def all_reduce(tensor, opts = AllreduceOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allreduce([tensor], opts) else: work = group.allreduce([tensor], opts) @@ -1121,7 +1122,7 @@ def all_reduce(tensor, def all_reduce_coalesced(tensors, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ WARNING: at this time individual shape checking is not implemented across nodes. @@ -1146,7 +1147,8 @@ def all_reduce_coalesced(tensors, op (Optional[ReduceOp]): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (Optional[ProcessGroup]): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (Optional[bool]): Whether this op should be an async op. Returns: @@ -1165,8 +1167,8 @@ def all_reduce_coalesced(tensors, opts = AllreduceCoalescedOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allreduce_coalesced(tensors, opts) else: work = group.allreduce_coalesced(tensors, opts) @@ -1180,7 +1182,7 @@ def all_reduce_coalesced(tensors, def reduce_multigpu(tensor_list, dst, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False, dst_tensor=0): """ @@ -1202,7 +1204,8 @@ def reduce_multigpu(tensor_list, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op dst_tensor (int, optional): Destination tensor rank within ``tensor_list`` @@ -1220,8 +1223,8 @@ def reduce_multigpu(tensor_list, opts.rootRank = dst opts.rootTensor = dst_tensor - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.reduce(tensor_list, opts) else: group_dst_rank = _get_group_rank(group, dst) @@ -1237,7 +1240,7 @@ def reduce_multigpu(tensor_list, def reduce(tensor, dst, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces the tensor data across all machines. @@ -1251,7 +1254,8 @@ def reduce(tensor, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1267,8 +1271,8 @@ def reduce(tensor, opts.reduceOp = op opts.rootRank = dst - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.reduce([tensor], opts) else: group_dst_rank = _get_group_rank(group, dst) @@ -1283,7 +1287,7 @@ def reduce(tensor, def all_gather_multigpu(output_tensor_lists, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Gathers tensors from the whole group in a list. @@ -1318,7 +1322,8 @@ def all_gather_multigpu(output_tensor_lists, Note that ``len(input_tensor_list)`` needs to be the same for all the distributed processes calling this function. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1332,8 +1337,8 @@ def all_gather_multigpu(output_tensor_lists, output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allgather(output_tensor_lists, input_tensor_list) else: work = group.allgather(output_tensor_lists, input_tensor_list) @@ -1358,7 +1363,7 @@ def _tensor_to_object(tensor, tensor_size): return out -def all_gather_object(object_list, obj, group=group.WORLD): +def all_gather_object(object_list, obj, group=None): """ Gathers picklable objects from the whole group into a list. Similar to :func:`all_gather`, but Python objects can be passed in. Note that the object @@ -1427,7 +1432,7 @@ def all_gather_object(object_list, obj, group=group.WORLD): ] # Allgather tensor sizes all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) + max_object_size = int(max(object_size_list).item()) # type: ignore # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) coalesced_output_tensor = torch.empty( @@ -1446,7 +1451,7 @@ def all_gather_object(object_list, obj, group=group.WORLD): object_list[i] = _tensor_to_object(tensor, tensor_size) -def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): +def gather_object(obj, object_gather_list=None, dst=0, group=None): """ Gathers picklable objects from the whole group in a single process. Similar to :func:`gather`, but Python objects can be passed in. Note that the @@ -1518,7 +1523,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): # gather, since each rank needs to broadcast a tensor of the same (maximal) # size. all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) + max_object_size = int(max(object_size_list).item()) # type: ignore # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) # Avoid populating output tensors if the result won't be gathered on this rank. @@ -1546,7 +1551,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): object_gather_list[i] = _tensor_to_object(tensor, tensor_size) -def broadcast_object_list(object_list, src, group=group.WORLD): +def broadcast_object_list(object_list, src, group=None): """ Broadcasts picklable objects in ``object_list`` to the whole group. Similar to :func:`broadcast`, but Python objects can be passed in. @@ -1739,7 +1744,7 @@ def scatter_object_list( def all_gather(tensor_list, tensor, - group=group.WORLD, + group=None, async_op=False): """ Gathers tensors from the whole group in a list. @@ -1750,7 +1755,8 @@ def all_gather(tensor_list, tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. tensor (Tensor): Tensor to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1795,8 +1801,8 @@ def all_gather(tensor_list, tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allgather([tensor_list], [tensor]) else: work = group.allgather([tensor_list], [tensor]) @@ -1808,7 +1814,7 @@ def all_gather(tensor_list, def all_gather_coalesced(output_tensor_lists, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Gathers input tensors from the whole group in a list in a coalesced manner. @@ -1820,7 +1826,8 @@ def all_gather_coalesced(output_tensor_lists, correctly-sized tensors to be used for output of the collective. input_tensor_list (list[Tensor]): Tensors to be broadcast from current process. At least one tensor has to be non empty. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -1866,8 +1873,8 @@ def all_gather_coalesced(output_tensor_lists, output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.allgather_coalesced( output_tensor_lists, input_tensor_list) else: @@ -1894,7 +1901,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): def gather(tensor, gather_list=None, dst=0, - group=group.WORLD, + group=None, async_op=False): """ Gathers a list of tensors in a single process. @@ -1905,7 +1912,8 @@ def gather(tensor, tensors to use for gathered data (default is None, must be specified on the destination rank) dst (int, optional): Destination rank (default is 0) - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1932,8 +1940,8 @@ def gather(tensor, opts = GatherOptions() opts.rootRank = dst - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.gather(output_tensors, input_tensors, opts) else: group_dst_rank = _get_group_rank(group, dst) @@ -1949,7 +1957,7 @@ def gather(tensor, def scatter(tensor, scatter_list=None, src=0, - group=group.WORLD, + group=None, async_op=False): """ Scatters a list of tensors to all processes in a group. @@ -1962,7 +1970,8 @@ def scatter(tensor, scatter_list (list[Tensor]): List of tensors to scatter (default is None, must be specified on the source rank) src (int): Source rank (default is 0) - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1998,8 +2007,8 @@ def scatter(tensor, opts = ScatterOptions() opts.rootRank = src - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() work = default_pg.scatter(output_tensors, input_tensors, opts) else: group_src_rank = _get_group_rank(group, src) @@ -2015,7 +2024,7 @@ def scatter(tensor, def reduce_scatter_multigpu(output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduce and scatter a list of tensors to the whole group. Only nccl backend @@ -2049,7 +2058,8 @@ def reduce_scatter_multigpu(output_tensor_list, therefore ``len(input_tensor_lists[i])``) need to be the same for all the distributed processes calling this function. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2063,8 +2073,8 @@ def reduce_scatter_multigpu(output_tensor_list, opts = ReduceScatterOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.reduce_scatter( output_tensor_list, input_tensor_lists, @@ -2086,7 +2096,7 @@ def reduce_scatter_multigpu(output_tensor_list, def reduce_scatter(output, input_list, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces, then scatters a list of tensors to all processes in a group. @@ -2094,7 +2104,8 @@ def reduce_scatter(output, Arguments: output (Tensor): Output tensor. input_list (list[Tensor]): List of tensors to reduce and scatter. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2110,8 +2121,8 @@ def reduce_scatter(output, opts = ReduceScatterOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.reduce_scatter([output], [input_list], opts) else: work = group.reduce_scatter([output], [input_list], opts) @@ -2126,7 +2137,7 @@ def all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, - group=group.WORLD, + group=None, async_op=False): """ Each process splits input tensor and then scatters the split list @@ -2142,7 +2153,8 @@ def all_to_all_single(output, input_split_sizes: (list[Int], optional): Input split sizes for dim 0 if specified None or empty, dim 0 of ``input`` tensor must divide equally by ``world_size``. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2206,8 +2218,8 @@ def all_to_all_single(output, output_split_sizes = [] if output_split_sizes is None else output_split_sizes input_split_sizes = [] if input_split_sizes is None else input_split_sizes - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts) else: work = group.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts) @@ -2219,7 +2231,7 @@ def all_to_all_single(output, def all_to_all(output_tensor_list, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Each process scatters list of input tensors to all processes in a group and @@ -2229,7 +2241,8 @@ def all_to_all(output_tensor_list, output_tensor_list (list[Tensor]): List of tensors to be gathered one per rank. input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -2297,8 +2310,8 @@ def all_to_all(output_tensor_list, _check_tensor_list(output_tensor_list, "output_tensor_list") _check_tensor_list(input_tensor_list, "input_tensor_list") - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.alltoall(output_tensor_list, input_tensor_list, opts) else: work = group.alltoall(output_tensor_list, input_tensor_list, opts) @@ -2309,7 +2322,7 @@ def all_to_all(output_tensor_list, work.wait() -def barrier(group=group.WORLD, +def barrier(group=GroupMember.WORLD, async_op=False): """ Synchronizes all processes. @@ -2318,7 +2331,8 @@ def barrier(group=group.WORLD, if async_op is False, or if async work handle is called on wait(). Arguments: - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -2328,8 +2342,8 @@ def barrier(group=group.WORLD, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - default_pg = _check_default_pg() + if group is None: + default_pg = _get_default_group() work = default_pg.barrier() else: work = group.barrier() @@ -2380,7 +2394,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None): global _pg_group_ranks - default_pg = _check_default_pg() + default_pg = _get_default_group() default_backend, default_store = _pg_map[default_pg] global_rank = default_pg.rank() global_world_size = default_pg.size() diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index a0412d52df0d..4181db799b28 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -733,6 +733,7 @@ class CatTransform(Transform): """ def __init__(self, tseq, dim=0, lengths=None, cache_size=0): assert all(isinstance(t, Transform) for t in tseq) + self.event_dim = max(t.event_dim for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] super(CatTransform, self).__init__(cache_size=cache_size) @@ -784,9 +785,20 @@ def log_abs_det_jacobian(self, x, y): for trans, length in zip(self.transforms, self.lengths): xslice = x.narrow(self.dim, start, length) yslice = y.narrow(self.dim, start, length) - logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice)) + logdetjac = trans.log_abs_det_jacobian(xslice, yslice) + if trans.event_dim < self.event_dim: + logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim) + logdetjacs.append(logdetjac) start = start + length # avoid += for jit compat - return torch.cat(logdetjacs, dim=self.dim) + # Decide whether to concatenate or sum. + dim = self.dim + if dim >= 0: + dim = dim - x.dim() + dim = dim + self.event_dim + if dim < 0: + return torch.cat(logdetjacs, dim=dim) + else: + return sum(logdetjacs) @property def bijective(self): diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py new file mode 100644 index 000000000000..4b5ea52c5c1f --- /dev/null +++ b/torch/fx/experimental/const_fold.py @@ -0,0 +1,269 @@ +import operator +from typing import Dict, Set, List, Optional + +import torch.fx +from torch.fx.experimental.subgraph_creation_example import split_module +import re + + +def _make_tuple(x): + """ + Helper to convert x into a one item tuple if it's not a tuple already. + """ + return x if isinstance(x, tuple) else (x,) + + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + const_output_names: Optional[List[str]] = None, + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.const_output_names = const_output_names + self.has_folding_been_run = False + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if self.const_subgraph_module is None or self.const_output_names is None: + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. We _make_tuple here because + # single attr const fold subgraphs output a single Tensor while + # multiple outputs are returned as Tuple[Tensor,]. + folded_attrs = _make_tuple(self.const_subgraph_module()) + + # Look for output node from const folding subgraph and set attrs on the + # module with the results. + for i in range(len(folded_attrs)): + setattr( + self, self.const_output_names[i], torch.nn.Parameter(folded_attrs[i]) + ) + + +def split_const_subgraphs( + module: torch.nn.Module, +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + mod_traced = torch.fx.symbolic_trace(module) + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: Set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op == "get_attr" or set(node.all_input_nodes).issubset(const_nodes): + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + # Gather all names that are output from the const folding subgraph, which we + # will need to set dummy params on the module. + const_output_names: List[str] = [] + for node in split.submod_0.graph.nodes: + if node.op == "output": + # Note: we _make_tuple here because the output Node either contains + # a single output Node, or Tuple[Node], so this simplifies things. + const_output_names = [o.name for o in _make_tuple(node.args[0])] + break + + # Make sure the attr name we want to use is uniquely named in the module. + for i in range(len(const_output_names)): + # Add a suffix to make it easier to tell these were the result of const folding. + name = const_output_names[i] + "__CF" + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + const_output_names[i] = name + + # Now track the const_output_names to what name is used in the parent graph + # from the split via call_function getitem, to see what order it is passed + # into the non-const subgraph submod_1. First look to the parent module + # containing/calling into the const/non-const submodules to determine what + # the inputs are to each. Note if submod_0 had a single output then there is + # no getitem, and we can simply use the output from the call to submoid_0. + call_submod_0_args, call_submod_1_args = None, None + orig_ph_targets: List[str] = [] + for node in split.graph.nodes: + if node.op == "placeholder": + orig_ph_targets.append(node.target) + + if node.op == "call_module": + if node.target == "submod_0": + call_submod_0_args = node.args + continue + elif node.target == "submod_1": + call_submod_1_args = node.args + continue + assert call_submod_0_args is not None and call_submod_1_args is not None + + # Look through the args for the call into submod_1, and find the args that + # come from submod_0. Also look for get_attrs fed directly from the parent + # split into submod_1, i.e. those attrs that are not constant folded. + submod_1_input_idx_to_folded_attr_name: Dict[int, str] = {} + submod_1_input_idx_to_unfolded_attr_name: Dict[int, str] = {} + for i, node in enumerate(call_submod_1_args): + const_output_name = None + # If we only had a single output from submod_0 then we simply look for + # the call_module into it. + if len(const_output_names) == 1: + if node.op == "call_module" and node.target == "submod_0": + const_output_name = const_output_names[0] + + # Else we had multiple outputs from submod_0, so we need to look for all + # getitems from the call to it. + else: + if ( + node.op == "call_function" + and node.target == operator.__getitem__ + and node.args[0].target == "submod_0" + ): + const_output_name = const_output_names[node.args[1]] + + # Now map from the index of the constant into calling submod_1 and map + # to the constant output name, which we use for swapping in getattrs + # instead of placeholders in submod_1. + if const_output_name is not None: + submod_1_input_idx_to_folded_attr_name[i] = const_output_name + elif node.op == "get_attr": + submod_1_input_idx_to_unfolded_attr_name[i] = node.target + + assert len(submod_1_input_idx_to_folded_attr_name) == len(const_output_names) + + # Now we have a mapping from const output names to the index they are passed + # into submod_1, so swap in getattrs for placeholders. + ph_idx = 0 + for node in split.submod_1.graph.nodes: + if node.op != "placeholder": + continue + is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name.keys() + is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name.keys() + if not is_folded_attr and not is_unfolded_attr: + ph_idx += 1 + continue + + const_output_name = ( + submod_1_input_idx_to_folded_attr_name[ph_idx] + if is_folded_attr + else submod_1_input_idx_to_unfolded_attr_name[ph_idx] + ) + if is_folded_attr: + assert not hasattr(mod_traced, const_output_name) + # Use a dummy param, which will be overwritten when we run const folding. + setattr( + mod_traced, + const_output_name, + torch.nn.Parameter(torch.randn(1)), + ) + with split.submod_1.graph.inserting_before(node): + node.replace_all_uses_with(split.submod_1.graph.get_attr(const_output_name)) + split.submod_1.graph.erase_node(node) + ph_idx += 1 + + # We may need to reorder placeholders to ensure they have the same order as + # they do in the original split. + ph_idx = 0 + node = next(iter(split.submod_1.graph.nodes)) + while node.op != "root": + if node.op != "placeholder": + node = node.next + continue + + curr_orig_ph_target = orig_ph_targets[ph_idx] + ph_idx += 1 + # If this ph is in the correct position, nothing to do. + if curr_orig_ph_target == node.target: + node = node.next + continue + + # This ph is not in the correct order, so search the rest of the graph + # for the ph we expected and prepend it before the current ph. + later_node = node.next + while later_node.op != "root": + if ( + later_node.op == "placeholder" + and curr_orig_ph_target == later_node.target + ): + break + later_node = later_node.next + assert later_node.op != "root" + node.prepend(later_node) + # Note we do not increment node here, as it still may be in the wrong + # place (we just prepended the ph that should have come before it). + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside submod_0, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to submod_0 must be constants accessible via get_attr. + for node in split.submod_0.graph.nodes: + if node.op != "placeholder": + continue + in_node = next(n for n in call_submod_0_args if n.name == node.target) + assert in_node.op == "get_attr" + with split.submod_0.graph.inserting_before(node): + node.replace_all_uses_with(split.submod_0.graph.get_attr(in_node.target)) + split.submod_0.graph.erase_node(node) + + return FoldedGraphModule( + mod_traced, split.submod_1.graph, split.submod_0.graph, const_output_names + ) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index f8bc96b73c40..e6fc19a1394e 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -549,7 +549,7 @@ def illegal_shadowing_name(name : str) -> bool: _shadows_builtin_name(name) while candidate in self._used_names or illegal_shadowing_name(candidate): - match = re.match(r"(.*)_(\d+)", candidate) + match = re.match(r"(.*)_(\d+)$", candidate) if match is None: candidate = candidate + '_1' else: diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 19085f155020..01ce71afd388 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -308,11 +308,7 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() { bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { for (size_t i = 0; i < devices_.size(); ++i) { // Checking the work's corresponding CUDA events' status - auto ret = cudaEventQuery((*cudaEvents_)[i]); - if (ret != cudaSuccess && ret != cudaErrorNotReady) { - AT_CUDA_CHECK(ret); - } - if (ret == cudaErrorNotReady) { + if (!(*cudaEvents_)[i].query()) { return false; } } diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index d0cfa5f80512..4b07682b1af7 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -831,7 +831,7 @@ def extra_repr(self) -> str: class MultiheadAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces. - See reference: Attention Is All You Need + See `Attention Is All You Need `_ .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O @@ -849,7 +849,7 @@ class MultiheadAttention(Module): vdim: total number of features in value. Default: None. Note: if kdim and vdim are None, they will be set to embed_dim such that - query, key, and value have the same number of features. + query, key, and value have the same number of features. Examples:: diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 3c8c626047dc..8256b111b988 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -15,12 +15,15 @@ class PixelShuffle(Module): `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ by Shi et. al (2016) for more details. + Note that this function can take inputs with any number of batch dimensions: + :math:`(L, H_{in}, W_{in})`, :math:`(N, L, H_{in}, W_{in})`, :math:`(N_1, N_2, L, H_{in}, W_{in})`, etc. + Args: upscale_factor (int): factor to increase spatial resolution by Shape: - - Input: :math:`(N, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2` - - Output: :math:`(N, C, H_{out}, W_{out})` where + - Input: :math:`(*, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2` + - Output: :math:`(*, C, H_{out}, W_{out})` where :math:`H_{out} = H_{in} \times \text{upscale\_factor}` and :math:`W_{out} = W_{in} \times \text{upscale\_factor}` diff --git a/torch/serialization.py b/torch/serialization.py index 7ae6abafa232..ebc5d0a08541 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -524,7 +524,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): deserialization methods using :func:`torch.serialization.register_package`. Args: - f: a file-like object (has to implement :meth:`read`, :meth`readline`, :meth`tell`, and :meth`seek`), + f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), or a string or os.PathLike object containing a file name map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index f8280f9fb57d..6577b1c4559f 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -21,6 +21,7 @@ import warnings import random import contextlib +import shutil import socket import subprocess import time @@ -300,11 +301,11 @@ def run_tests(argv=UNITTEST_ARGS): if IS_WINDOWS: @contextmanager - def TemporaryFileName(): + def TemporaryFileName(dir=None): # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile # opens the file, and it cannot be opened multiple times in Windows. To support Windows, # close the file after creation and try to remove it manually - f = tempfile.NamedTemporaryFile(delete=False) + f = tempfile.NamedTemporaryFile(delete=False, dir=dir) try: f.close() yield f.name @@ -312,10 +313,27 @@ def TemporaryFileName(): os.unlink(f.name) else: @contextmanager # noqa: T484 - def TemporaryFileName(): - with tempfile.NamedTemporaryFile() as f: + def TemporaryFileName(dir=None): + with tempfile.NamedTemporaryFile(dir=dir) as f: yield f.name +if IS_WINDOWS: + @contextmanager + def TemporaryDirectoryName(suffix=None): + # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely, + # so we first create the directory using mkdtemp and then remove it manually + try: + dir_name = tempfile.mkdtemp(suffix=suffix) + yield dir_name + finally: + shutil.rmtree(dir_name) +else: + @contextmanager # noqa: T484 + def TemporaryDirectoryName(suffix=None): + with tempfile.TemporaryDirectory(suffix=suffix) as d: + yield d + +IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8' def _check_module_exists(name): r"""Returns if a top-level module with :attr:`name` exists *without** diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index 376fdb8049b9..4f14584af3b1 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -285,16 +285,6 @@ def test_invalid_devices(self): ) ) - with self.assertRaisesRegex( - RuntimeError, r"CPU device index must be -1 or zero, got 2" - ): - list( - self._create_remote_module_iter( - "{}/cpu:2".format(dst_worker_name), - modes=[ModuleCreationMode.MODULE_CTOR], - ) - ) - with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): list( self._create_remote_module_iter( diff --git a/torch/testing/check_kernel_launches.py b/torch/testing/check_kernel_launches.py index 091f1be98561..c274316b54fe 100644 --- a/torch/testing/check_kernel_launches.py +++ b/torch/testing/check_kernel_launches.py @@ -18,12 +18,13 @@ # But this should be sufficient to detect and fix most problem # instances and can be refined before the test is made binding kernel_launch_regex = re.compile(r""" - >>> # Identifies kernel launch + ^.*>>> # Identifies kernel launch \s* # Maybe some whitespace (includes newlines) \([^;]+\); # And then arguments in parens and semi-colon (?! # Negative lookahead: we trigger if we don't find the launch guard \s* # Maybe some whitespace (includes newlines) \\? # 0 or 1 backslashes (for launches in preprocessor macros) + \s* # Maybe some whitespace (includes newlines) (?:[0-9]+: )? # Detects and ignores a line numbering, if present \s* # Maybe some whitespace (includes newlines) C10_CUDA_KERNEL_LAUNCH_CHECK\(\); # Kernel launch guard!