diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index ebaddda7ca26..a1d045ccd400 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -18,7 +18,11 @@ ("clang", [ ("5", [ ("3.6", [ - ("asan", [XImportant(True)]), + ("asan", [ + (True, [ + ("shard_test", [XImportant(True)]), + ]), + ]), ]), ]), ("7", [ @@ -45,14 +49,14 @@ ]), ("10.2", [ ("3.6", [ - ("important", [X(True)]), + ("shard_test", [XImportant(True)]), ("libtorch", [X(True)]), ]), ]), ("11.0", [ ("3.8", [ X(True), - ("libtorch", [XImportant(True)]) + ("libtorch", [XImportant(True)]), ]), ]), ]), @@ -158,6 +162,7 @@ def child_constructor(self): "libtorch": LibTorchConfigNode, "important": ImportantConfigNode, "build_only": BuildOnlyConfigNode, + "shard_test": ShardTestConfigNode, "cuda_gcc_override": CudaGccOverrideConfigNode, "coverage": CoverageConfigNode, "pure_torch": PureTorchConfigNode, @@ -195,7 +200,7 @@ def init2(self, node_name): self.props["is_asan"] = node_name def child_constructor(self): - return ImportantConfigNode + return ExperimentalFeatureConfigNode class ONNXConfigNode(TreeConfigNode): @@ -260,8 +265,8 @@ def init2(self, node_name): def child_constructor(self): return ExperimentalFeatureConfigNode -class BuildOnlyConfigNode(TreeConfigNode): +class BuildOnlyConfigNode(TreeConfigNode): def init2(self, node_name): self.props["build_only"] = node_name @@ -269,8 +274,15 @@ def child_constructor(self): return ExperimentalFeatureConfigNode -class CoverageConfigNode(TreeConfigNode): +class ShardTestConfigNode(TreeConfigNode): + def init2(self, node_name): + self.props["shard_test"] = node_name + + def child_constructor(self): + return ImportantConfigNode + +class CoverageConfigNode(TreeConfigNode): def init2(self, node_name): self.props["is_coverage"] = node_name @@ -290,7 +302,6 @@ def get_children(self): class XenialCompilerConfigNode(TreeConfigNode): - def modify_label(self, label): return label or "" @@ -304,7 +315,6 @@ def child_constructor(self): class BionicCompilerConfigNode(TreeConfigNode): - def modify_label(self, label): return label or "" diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index 3afe37b29f2d..0c03fac487d6 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -288,7 +288,6 @@ def instantiate_configs(): rocm_version = None if compiler_name == "cuda": cuda_version = fc.find_prop("compiler_version") - restrict_phases = ["build", "test1", "test2"] elif compiler_name == "rocm": rocm_version = fc.find_prop("compiler_version") @@ -311,7 +310,6 @@ def instantiate_configs(): parms_list.append("asan") python_version = fc.find_prop("pyver") parms_list[0] = fc.find_prop("abbreviated_pyver") - restrict_phases = ["build", "test1", "test2"] if is_onnx: parms_list.append("onnx") @@ -328,7 +326,11 @@ def instantiate_configs(): parallel_backend = fc.find_prop("parallel_backend") or None build_only = fc.find_prop("build_only") or False is_coverage = fc.find_prop("is_coverage") or False + shard_test = fc.find_prop("shard_test") or False # TODO: fix pure_torch python test packaging issue. 
+ if shard_test: + restrict_phases = ["build"] if restrict_phases is None else restrict_phases + restrict_phases.extend(["test1", "test2"]) if build_only or is_pure_torch: restrict_phases = ["build"] if is_coverage and restrict_phases is None: diff --git a/.circleci/config.yml b/.circleci/config.yml index 1ee03badb92a..ef11e702b0c9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -142,7 +142,7 @@ commands: name: (Optional) Merge target branch no_output_timeout: "10m" command: | - if [ -n "$CIRCLE_PULL_REQUEST" ]; then + if [[ -n "$CIRCLE_PULL_REQUEST" && "$CIRCLE_BRANCH" != "nightly" ]]; then PR_NUM=$(basename $CIRCLE_PULL_REQUEST) CIRCLE_PR_BASE_BRANCH=$(curl -s https://api.github.com/repos/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME/pulls/$PR_NUM | jq -r '.base.ref') if [[ "${BUILD_ENVIRONMENT}" == *"xla"* || "${BUILD_ENVIRONMENT}" == *"gcc5"* ]] ; then @@ -6668,7 +6668,7 @@ workflows: build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test1 + name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test requires: - pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build filters: @@ -6677,21 +6677,7 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test2 - requires: - - pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test2" + build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -6802,21 +6788,7 @@ workflows: build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 - requires: - - pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2 + name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test requires: - pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build filters: @@ -6825,7 +6797,7 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test2" + build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" use_cuda_docker_runtime: "1" 
resource_class: gpu.medium @@ -6842,7 +6814,7 @@ workflows: build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1 + name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test requires: - pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build filters: @@ -6851,21 +6823,7 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2 - requires: - - pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2" + build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -6876,18 +6834,10 @@ workflows: build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1 - requires: - - pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build - build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2 + name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test requires: - pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build - build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2" + build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium diff --git a/.circleci/verbatim-sources/commands.yml b/.circleci/verbatim-sources/commands.yml index cec3b3c588e2..dfa4ee4d16de 100644 --- a/.circleci/verbatim-sources/commands.yml +++ b/.circleci/verbatim-sources/commands.yml @@ -103,7 +103,7 @@ commands: name: (Optional) Merge target branch no_output_timeout: "10m" command: | - if [ -n "$CIRCLE_PULL_REQUEST" ]; then + if [[ -n "$CIRCLE_PULL_REQUEST" && "$CIRCLE_BRANCH" != "nightly" ]]; then PR_NUM=$(basename $CIRCLE_PULL_REQUEST) CIRCLE_PR_BASE_BRANCH=$(curl -s https://api.github.com/repos/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME/pulls/$PR_NUM | jq -r '.base.ref') if [[ "${BUILD_ENVIRONMENT}" == *"xla"* || "${BUILD_ENVIRONMENT}" == *"gcc5"* ]] ; then diff --git a/.clang-tidy b/.clang-tidy index e062760cf75c..a540d67a130e 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,7 +1,6 @@ --- # NOTE there must be no spaces before the '-', so put the comma last. 
-InheritParentConfig: true -Checks: ' +Checks: '-*, bugprone-*, -bugprone-forward-declaration-namespace, -bugprone-macro-parentheses, @@ -18,7 +17,6 @@ cppcoreguidelines-*, -cppcoreguidelines-pro-type-union-access, -cppcoreguidelines-pro-type-vararg, -cppcoreguidelines-special-member-functions, --facebook-hte-RelativeInclude, hicpp-exception-baseclass, hicpp-avoid-goto, modernize-*, @@ -29,7 +27,7 @@ modernize-*, -modernize-use-trailing-return-type, performance-*, -performance-noexcept-move-constructor, -' + ' HeaderFilterRegex: 'torch/csrc/.*' AnalyzeTemporaryDtors: false CheckOptions: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b1b539788ba4..e54b5254e24c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -110,10 +110,10 @@ jobs: # Install dependencies pip install pyyaml wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - - sudo apt-add-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-8 main" + sudo apt-add-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-11 main" sudo apt-get update - sudo apt-get install -y clang-tidy-8 - sudo update-alternatives --install /usr/bin/clang-tidy clang-tidy /usr/bin/clang-tidy-8 1000 + sudo apt-get install -y clang-tidy-11 + sudo update-alternatives --install /usr/bin/clang-tidy clang-tidy /usr/bin/clang-tidy-11 1000 - name: Run clang-tidy run: | set -eux diff --git a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat index 504d3b931bc7..417b021df0cd 100644 --- a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat +++ b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat @@ -103,7 +103,7 @@ if "%USE_CUDA%"=="1" ( :: in PATH, and then pass the arguments to it. :: Currently, randomtemp is placed before sccache (%TMP_DIR_WIN%\bin\nvcc) :: so we are actually pretending sccache instead of nvcc itself. 
- curl -kL https://github.com/peterjc123/randomtemp-rust/releases/download/v0.2/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe + curl -kL https://github.com/peterjc123/randomtemp-rust/releases/download/v0.3/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe set RANDOMTEMP_EXECUTABLE=%TMP_DIR_WIN%\bin\nvcc.exe set CUDA_NVCC_EXECUTABLE=%TMP_DIR_WIN%\bin\randomtemp.exe set RANDOMTEMP_BASEDIR=%TMP_DIR_WIN%\bin diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat index cf7255ce3789..7669f6cfd91e 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat @@ -12,7 +12,7 @@ call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Minic if "%REBUILD%"=="" ( call conda install -y -q python=%PYTHON_VERSION% numpy cffi pyyaml boto3 call conda install -y -q -c conda-forge cmake - call conda install -y -q -c rdonnelly libuv + call conda install -y -q -c conda-forge libuv=1.39 ) :: Get installed libuv path diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 640ad0c181e4..136760263ca1 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -317,6 +317,11 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("median.dim_values", CppFunction::makeFallthrough()); m.impl("median.names_dim", CppFunction::makeFallthrough()); m.impl("median.names_dim_values", CppFunction::makeFallthrough()); + m.impl("nanmedian", CppFunction::makeFallthrough()); + m.impl("nanmedian.dim", CppFunction::makeFallthrough()); + m.impl("nanmedian.dim_values", CppFunction::makeFallthrough()); + m.impl("nanmedian.names_dim", CppFunction::makeFallthrough()); + m.impl("nanmedian.names_dim_values", CppFunction::makeFallthrough()); m.impl("min", CppFunction::makeFallthrough()); m.impl("min.dim", CppFunction::makeFallthrough()); m.impl("min.dim_min", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 3c82ecdc48c0..ce780f58f3a6 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -137,7 +137,6 @@ _(aten, _th_baddbmm) \ _(aten, _th_bmm) \ _(aten, _th_get_device) \ _(aten, _th_kthvalue) \ -_(aten, _th_median) \ _(aten, _th_mode) \ _(aten, _th_prod) \ _(aten, _th_sigmoid) \ @@ -463,6 +462,7 @@ _(aten, max_unpool3d_forward) \ _(aten, max_values) \ _(aten, mean) \ _(aten, median) \ +_(aten, nanmedian) \ _(aten, meshgrid) \ _(aten, min) \ _(aten, min_values) \ @@ -902,6 +902,7 @@ _(attr, maxnorm) \ _(attr, maximum) \ _(attr, mean) \ _(attr, median) \ +_(attr, nanmedian) \ _(attr, min) \ _(attr, min_indices) \ _(attr, min_val) \ diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index 00e57ca63520..f7381813d0b5 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -77,22 +77,24 @@ namespace at { namespace cuda { #define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR cuOccupancyMaxActiveBlocksPerMultiprocessor #endif -#define AT_FORALL_NVRTC(_) \ - _(nvrtcVersion) \ - _(nvrtcCreateProgram) \ - _(nvrtcDestroyProgram) \ - _(nvrtcGetPTXSize) \ - _(nvrtcGetPTX) \ - _(cuModuleLoadData) \ - _(cuModuleGetFunction) \ - 
_(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR)\ - _(nvrtcGetErrorString) \ - _(nvrtcGetProgramLogSize) \ - _(nvrtcGetProgramLog) \ - _(cuLaunchKernel) \ - _(nvrtcCompileProgram) \ - _(cuCtxGetCurrent) \ - _(cuModuleUnload) \ +#define AT_FORALL_NVRTC(_) \ + _(nvrtcVersion) \ + _(nvrtcCreateProgram) \ + _(nvrtcAddNameExpression) \ + _(nvrtcDestroyProgram) \ + _(nvrtcGetPTXSize) \ + _(nvrtcGetPTX) \ + _(cuModuleLoadData) \ + _(cuModuleGetFunction) \ + _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \ + _(nvrtcGetErrorString) \ + _(nvrtcGetProgramLogSize) \ + _(nvrtcGetProgramLog) \ + _(cuLaunchKernel) \ + _(nvrtcCompileProgram) \ + _(cuCtxGetCurrent) \ + _(nvrtcGetLoweredName) \ + _(cuModuleUnload) \ _(cuDevicePrimaryCtxGetState) #endif diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 9c3742c129de..0733b438e8e6 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -1425,9 +1426,9 @@ static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) } // Performs matrix norm -static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, +static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, optional opt_ord, IntArrayRef dim, bool keepdim, optional opt_dtype) { - Tensor result; + Tensor result_; auto ord = opt_ord.value_or(2.0).toDouble(); TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, "matrix norm only supports CPU AND CUDA device type, got: ", self.device().type()); @@ -1460,12 +1461,12 @@ static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); auto permutation_reverse = create_reverse_permutation(permutation); - result = std::get<1>(self_.permute(permutation).svd()).abs(); - result = _norm_min_max(result, ord, result.dim() - 1, keepdim); + result_ = std::get<1>(self_.permute(permutation).svd()).abs(); + result_ = _norm_min_max(result_, ord, result_.dim() - 1, keepdim); if (keepdim) { - result.unsqueeze_(-1); - result = result.permute(permutation_reverse); + result_.unsqueeze_(-1); + result_ = result_.permute(permutation_reverse); } } else { // abs(p) == infinity and abs(p) == 1 will perform identical reductions, except @@ -1482,12 +1483,14 @@ static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, dim_[1]--; } if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) { - result = self_.abs().sum(dim_[0], keepdim); - result = _norm_min_max(result, ord, dim_[1], keepdim); + result_ = self_.abs().sum(dim_[0], keepdim); + result_ = _norm_min_max(result_, ord, dim_[1], keepdim); } else { TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm"); } } + resize_output(result, result_.sizes()); + result.copy_(result_); return result; } @@ -1495,7 +1498,9 @@ static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, // This function mostly serves as a wrapper for at::norm, but it overrides a few cases // for numpy compatibility. These cases are corrected within this wrapper, rather than // in at::norm itself, to avoid breaking backward compatibility. 
-static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, std::vector dim, bool keepdim, optional opt_dtype) { +static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optional opt_ord, std::vector dim, bool keepdim, optional opt_dtype) { + Tensor result_; + bool case_was_overridden = false; if (opt_ord.has_value()) { TORCH_INTERNAL_ASSERT(dim.size() == 1); auto ord = opt_ord.value().toDouble(); @@ -1504,20 +1509,15 @@ static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, // The ord = +/-infinity case is overridden because at::norm does not match numpy // when the input contains extreme values (like nan or +/-inf) or if the input // size is degenerate (like size(0), size(0, N), etc) - self_ = self_.abs(); - return _norm_min_max(self_, ord, dim[0], keepdim); + case_was_overridden = true; + self_ = self.abs(); + result_ = _norm_min_max(self_, ord, dim[0], keepdim); } else if ((self_.numel() == 0) && (ord < 0)) { // For negative orders with degenerate input sizes, at::norm's result does not - // match numpy. - Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim); - if (ord >= -1) { - // Result must be infinite in this case, and the simplest way to make that - // happen is to simply add infinity - result += INFINITY; - } else { - result = result.pow(1.0 / (ord + 1)); - } - return result; + // match numpy. It should always be infinity. + auto mask = make_dim_mask(dim[0], self_.dim()); + allocate_reduction_result(result, self_, mask, keepdim, result.scalar_type()); + return result.fill_(INFINITY); } } else { // If ord == None, need to check for unique dims because at::norm does not check it @@ -1527,11 +1527,16 @@ static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end(); TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")"); } - if (opt_dtype.has_value()) { - return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value()); - } else { - return at::norm(self, opt_ord, dim, keepdim); + if (!case_was_overridden) { + if (opt_dtype.has_value()) { + result_ = at::norm(self, opt_ord, dim, keepdim, opt_dtype.value()); + } else { + result_ = at::norm(self, opt_ord, dim, keepdim); + } } + resize_output(result, result_.sizes()); + result.copy_(result_); + return result; } static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional opt_num_ord, optional opt_str_ord, optional opt_dim, bool keepdim, optional opt_dtype) { @@ -1544,34 +1549,31 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional "dtype = ", dtype, ", out.dtype = ", result.scalar_type()); } int64_t ndim = self.dim(); - Tensor result_; if (opt_str_ord.has_value()) { // 'ord' is string auto str_ord = opt_str_ord.value(); check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype); if (str_ord == "fro") { - result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); + at::frobenius_norm_out(result, self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); } else if (str_ord == "nuc") { if (opt_dim.has_value()) { - result_ = at::nuclear_norm(self, opt_dim.value(), keepdim); + at::nuclear_norm_out(result, self, opt_dim.value(), keepdim); } else { - result_ = at::nuclear_norm(self, keepdim); + at::nuclear_norm_out(result, self, keepdim); } } } else { // 'ord' is int or None std::vector dim_ = opt_dim.has_value() ? 
opt_dim.value().vec() : make_dim_list(ndim); if (!opt_num_ord.has_value() || dim_.size() == 1) { - result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype); + _linalg_norm_vector_out(result, self, opt_num_ord, dim_, keepdim, opt_dtype); } else if (dim_.size() == 2) { - result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype); + _linalg_norm_matrix_out(result, self, opt_num_ord.value(), dim_, keepdim, opt_dtype); } else { TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is " "not 1-D or 2-D"); } } - resize_output(result, result_.sizes()); - result.copy_(result_); return result; } diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 7394365903ed..79746731a7ff 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -17,10 +17,7 @@ DEFINE_DISPATCH(topk_stub); namespace { -// maybe these days, one should define a random access iterator and use -// std::sort... /* Note from TH: - I cut and pasted (slightly adapted) the quicksort code from Sedgewick's 1978 "Implementing Quicksort Programs" article http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf @@ -35,7 +32,6 @@ namespace { Julien, November 12th 2013 */ - template void quick_select_template( TensorAccessor arr, @@ -272,6 +268,144 @@ std::tuple kthvalue_out_impl_cpu( return std::forward_as_tuple(values, indices); } +// Computes both the median and its index along dimension dim of the input +std::tuple median_with_indices_impl( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim, + bool ignore_nan) { + dim = at::maybe_wrap_dim(dim, self.dim()); + + int64_t size = self.dim() > 0 ? self.size(dim) : 1; + TORCH_CHECK( + size > 0, + "median() cannot compute median for a dimension of size 0 because ", + "the operation does not have an identity"); + + checkDeviceType("median", {values, indices}, self.device().type()); + checkScalarType("median", {indices, "indices", 1}, kLong); + checkSameType("median", {values, "values", 0}, {self, "self", 2}); + + std::vector out_shape = self.sizes().vec(); + if (self.dim() > 0) { + if (keepdim) { + out_shape[dim] = 1; + } else { + out_shape.erase(out_shape.begin() + dim); + } + } + + resize_output(values, out_shape); + resize_output(indices, out_shape); + + // Ensure #dim is the same for all tensors required for dim_apply + Tensor in = self.dim() > 0 ? self : self.unsqueeze(0); + Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim); + Tensor inds = keepdim && self.dim() > 0 ? 
indices : indices.unsqueeze(dim); + + // Make dim to reduce contiguous (stride=1) + if (in.stride(dim) > 1) { + in = in.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim).contiguous(); + vals = vals.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim); + inds = inds.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim); + dim = in.dim() - 1; + } + + AT_DISPATCH_ALL_TYPES(in.scalar_type(), "median_out", [&] { + dim_apply({in, vals, inds}, dim, [&](int64_t it, TensorList tl) { + // Make the current row to be reduced contiguous + scalar_t* ip = tl[0].data_ptr(); + + // For torch.median, search for NaN and return it if found + if (!ignore_nan) { + scalar_t* nanp = std::find_if(ip, ip + size, _isnan); + if (nanp != ip + size) { + *tl[1].data_ptr() = *nanp; + *tl[2].data_ptr() = nanp - ip; + return; + } + } + + // Vector of indices for indirectly partitioning input around median + std::vector idx(size); + auto first = idx.begin(); + auto last = idx.end(); + std::iota(first, last, 0); + + // We partition the input around the median indirectly using the indices + // vector so that nth points to the index of the median in the unmodified + // input tensor. + auto nth = first; + if (!ignore_nan) { + // If we got here, there are no nan values + nth += (size - 1) / 2; + std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) { + return ip[i] < ip[j] || (ip[i] == ip[j] && i < j); + }); + } else { + // For torch.nanmedian, compute median of non-nan values only + int64_t num_nan = std::count_if(ip, ip + size, _isnan); + nth += (size - num_nan - 1) / 2; + std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) { + return ip[i] < ip[j] || (ip[i] == ip[j] && i < j) || + (_isnan(ip[j]) && !_isnan(ip[i])); + }); + } + + *tl[1].data_ptr() = ip[*nth]; + *tl[2].data_ptr() = *nth; + }); + }); + + return std::forward_as_tuple(values, indices); +} + +// Computes the median of all values in the input +Tensor median_impl(const Tensor& self, bool ignore_nan) { + NoNamesGuard guard; + + int64_t size = self.numel(); + TORCH_CHECK( + size > 0, + "median() operation does not have an identity for empty input tensor"); + + // Clone the input tensor so we can partition it around the median value + Tensor in = self.clone(); + Tensor out = at::empty({}, self.options()); + + AT_DISPATCH_ALL_TYPES(in.scalar_type(), "median_cpu", [&] { + scalar_t* op = out.data_ptr(); + scalar_t* first = in.data_ptr(); + scalar_t* last = first + size; + + // For torch.median, if there are nan values return nan + if (!ignore_nan && std::any_of(first, last, _isnan)) { + *op = std::numeric_limits::quiet_NaN(); + return; + } + + scalar_t* median = first; + if (!ignore_nan) { + // If we got here, there are no nan values + median += (size - 1) / 2; + std::nth_element(first, median, last); + } else { + // For torch.nanmedian, compute median of non-nan values only + int64_t num_nan = std::count_if(first, last, _isnan); + median += (size - num_nan - 1) / 2; + std::nth_element(first, median, last, [](scalar_t a, scalar_t b) { + return a < b || (_isnan(b) && !_isnan(a)); + }); + } + + *op = *median; + }); + + return out; +} + } // namespace Tensor& quantile_out( @@ -451,43 +585,20 @@ std::tuple topk( return std::make_tuple(values, indices); } -// this does not reduce to median with dim because we don't want to copy twice -Tensor median_cpu(const Tensor& self) { - NoNamesGuard guard; - TORCH_CHECK(self.numel() > 0, "median cannot be called with empty tensor"); - if (self.dim() == 0 && self.numel() == 1) { - return self.clone(at::MemoryFormat::Contiguous); - } 
- auto tmp_values = self.clone(at::MemoryFormat::Contiguous).view(-1); - auto result = at::empty({1}, self.options()); - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "median", [&] { - // note, quick_select is 0 based while kthvalue is not - int64_t k = (tmp_values.size(0) - 1) / 2; - auto val_accessor = tmp_values.accessor(); - quick_select_template( - val_accessor, - k, - [](scalar_t x, scalar_t y) -> bool { - return ((_isnan(x) && !_isnan(y)) || (x > y)); - }, - [&](int64_t i, int64_t j) { - std::swap(val_accessor[i], val_accessor[j]); - }); - result.fill_(tmp_values[k]); - }); - return result.view({}); -} - -std::tuple median_out( +std::tuple median_out_cpu( Tensor& values, Tensor& indices, const Tensor& self, int64_t dim, bool keepdim) { - // note: kthvalue counts from 1..n - int64_t k = self.dim() > 0 ? (self.size(dim) + 1) / 2 : 1; - at::kthvalue_out(values, indices, self, k, dim, keepdim); - return std::forward_as_tuple(values, indices); + auto result = [&]() { + NoNamesGuard guard; + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/false); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; } std::tuple median_out( @@ -517,13 +628,63 @@ std::tuple median( return at::median(self, dimname_to_position(self, dim), keepdim); } +Tensor median_cpu(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/false); +} + +std::tuple nanmedian_out_cpu( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim) { + auto result = [&]() { + NoNamesGuard guard; + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/true); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; +} + +std::tuple nanmedian_out( + Tensor& values, + Tensor& indices, + const Tensor& self, + Dimname dim, + bool keepdim) { + return at::nanmedian_out( + values, indices, self, dimname_to_position(self, dim), keepdim); +} + +std::tuple nanmedian( + const Tensor& self, + int64_t dim, + bool keepdim) { + Tensor values = at::empty({0}, self.options()); + Tensor indices = at::empty({0}, self.options().dtype(kLong)); + at::nanmedian_out(values, indices, self, dim, keepdim); + return std::make_tuple(values, indices); +} + +std::tuple nanmedian( + const Tensor& self, + Dimname dim, + bool keepdim) { + return at::nanmedian(self, dimname_to_position(self, dim), keepdim); +} + +Tensor nanmedian_cpu(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/true); +} + std::tuple sort_out_cpu( Tensor& values, Tensor& indices, const Tensor& self, int64_t dim, - bool descending - ) { + bool descending) { values.resize_(self.sizes()).copy_(self); indices.resize_(self.sizes()); @@ -541,8 +702,7 @@ std::tuple sort_out_cpu( std::tuple sort_cpu( const Tensor& self, int64_t dim, - bool descending - ) { + bool descending) { Tensor values = at::empty({0}, self.options()); Tensor indices = at::empty({0}, self.options().dtype(kLong)); return sort_out_cpu(values, indices, self, dim, descending); diff --git a/aten/src/ATen/native/SortingUtils.h b/aten/src/ATen/native/SortingUtils.h index 6b83be84ce90..3223fd3a779d 100644 --- a/aten/src/ATen/native/SortingUtils.h +++ b/aten/src/ATen/native/SortingUtils.h @@ -40,7 +40,7 @@ void dim_apply(TensorList tensors, int64_t dim, Fn f) { }); } -// ensure we 
get good values and indices for kthvalue, mode, median +// ensure we get good values and indices for kthvalue, mode // this will always be with the reducing dim as 1-d inline void _reduction_with_indices_allocate_or_resize_output( Tensor& values, diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 3dbd52aa1905..bca831dc55e2 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -17,6 +17,8 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) { cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; }); } else if (dtype == ScalarType::BFloat16) { cpu_kernel(iter, [=](at::BFloat16 a) -> at::BFloat16 { return a; }); + } else if (dtype == ScalarType::ComplexHalf) { + cpu_kernel(iter, [=](c10::complex a) -> c10::complex { return a; }); } else if (isQIntType(dtype)) { AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] { cpu_kernel_vec( diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index e9dfe2d9285d..f39d66e816a1 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -442,6 +442,20 @@ inline magma_int_t magmaGeqrfOptimalBlocksize(magma_int_t m, magma_int_t return magma_get_sgeqrf_nb(m, n); } +template <> +inline magma_int_t magmaGeqrfOptimalBlocksize>( + magma_int_t m, + magma_int_t n) { + return magma_get_zgeqrf_nb(m, n); +} + +template <> +inline magma_int_t magmaGeqrfOptimalBlocksize>( + magma_int_t m, + magma_int_t n) { + return magma_get_cgeqrf_nb(m, n); +} + template<> void magmaGeqrf( magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, @@ -468,6 +482,70 @@ void magmaGeqrf( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaGeqrf>( + magma_int_t m, + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t* info, + bool is_v2) { + MagmaStreamSyncGuard guard; + if (!is_v2) { + magma_zgeqrf_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + info); + } else { + magma_zgeqrf2_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + info); + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaGeqrf>( + magma_int_t m, + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t* info, + bool is_v2) { + MagmaStreamSyncGuard guard; + if (!is_v2) { + magma_cgeqrf_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + info); + } else { + magma_cgeqrf2_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + info); + } + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaOrgqr( magma_int_t m, magma_int_t n, magma_int_t k, double* dA, magma_int_t ldda, @@ -486,6 +564,56 @@ void magmaOrgqr( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaOrgqr>( + magma_int_t m, + magma_int_t n, + magma_int_t k, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t nb, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zungqr_gpu( + m, + n, + k, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + nb, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaOrgqr>( + magma_int_t m, + magma_int_t n, + magma_int_t k, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t 
nb, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cungqr_gpu( + m, + n, + k, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + nb, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaSymeig( magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, double* dA, magma_int_t ldda, @@ -1276,7 +1404,7 @@ std::tuple _qr_helper_cuda(const Tensor& self, bool some) { q_working_copy.narrow(-1, 0, self.size(-1)).copy_(self); r_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "qr_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "qr_cuda", [&]{ apply_qr(q_working_copy, r_working_copy, n_columns_q, infos); }); if (self.dim() > 2) { diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu new file mode 100644 index 000000000000..a0a9e35c916f --- /dev/null +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -0,0 +1,453 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include // only for THCRoundUp? +#include +#include +#include // AddOp + +#include +#include + +namespace at { +namespace native { + +namespace { + +// Finds the rank k element, and its index, of the values along dimension dim +template +__global__ void gatherKthValue( + cuda::detail::TensorInfo input, + index_t inputSliceSize, + index_t k, + index_t numInputSlices, + index_t inputWithinSliceStride, + cuda::detail::TensorInfo kthValue, + cuda::detail::TensorInfo indices) { + // Indices are limited to integer fp precision, so counts can fit in + // int32, regardless of index_t + __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit + + index_t slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Find the start offset for our slice + index_t sliceStartIndex = + cuda::detail::IndexToOffset::get(slice, input); + index_t kthValueSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, kthValue); + index_t indicesSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, indices); + + scalar_t* inputSliceStart = &input.data[sliceStartIndex]; + scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + // Find the k-th highest element in our input + scalar_t kValue = static_cast(0); + radixSelect< + scalar_t, + typename TopKTypeConfig::RadixType, + index_t, + false>( + inputSliceStart, + k, + inputSliceSize, + inputWithinSliceStride, + smem, + &kValue); + + // Find the index of the k-th highest element + index_t kValueIndex = 0; + bool foundKValue = false; + + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + scalar_t v = inRange ? 
doLdg(&inputSliceStart[i * inputWithinSliceStride]) + : static_cast(0); + bool isKValue = inRange && + ((v == kValue) || + (THCNumerics::isnan(v) && + THCNumerics::isnan(kValue))); + if (isKValue) { + kValueIndex = i; + foundKValue = true; + break; + } + } + + if (foundKValue) { + kthValueSliceStart[0] = kValue; + indicesSliceStart[0] = kValueIndex; + } +} + +// CUDA kernel to find the median, and its index, of the values along dimension dim +template +__global__ void gatherMedian( + cuda::detail::TensorInfo values, + cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo input, + index_t inputSliceSize, + index_t numInputSlices, + index_t inputWithinSliceStride, + bool ignore_nan) { + // Shared memory for the subroutine RadixSelect. Note that RadixSelect converts the + // floating point type to int with the same relative ordering. + __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit + + index_t slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Finds the start offset for our slice + index_t valuesSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, values); + index_t indicesSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, indices); + index_t inputSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, input); + + scalar_t* valuesSliceStart = &values.data[valuesSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + scalar_t* inputSliceStart = &input.data[inputSliceStartIndex]; + + index_t nan_count = 0; + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]); + nan_count += THCNumerics::isnan(val) ? 1 : 0; + } + + // Counts number of nan values + // This code performs a parallel sum reduction (not the most efficient code) + __shared__ int64_t num_nan; + if (threadIdx.x == 0) { + num_nan = 0; + } + __syncthreads(); + if (nan_count > 0) { + atomicAdd(&num_nan, nan_count); + } + __syncthreads(); + + // For torch.median, if we found nan set k to last index so the computed value + // is nan, otherwise set k to the middle element of the non-nan values + index_t k = (!ignore_nan && num_nan > 0) ? 
inputSliceSize - 1 + : (inputSliceSize - num_nan - 1) / 2; + + // Find the median + scalar_t median = static_cast(0); + radixSelect< + scalar_t, + typename TopKTypeConfig::RadixType, + index_t, + false>( + inputSliceStart, + k + 1, + inputSliceSize, + inputWithinSliceStride, + smem, + &median); + + valuesSliceStart[0] = median; + + // Find the index of the median value in the slice + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]); + if (val == median || + (THCNumerics::isnan(val) && + THCNumerics::isnan(median))) { + indicesSliceStart[0] = i; + break; + } + } +} + +struct KthValueLauncher { + int64_t k; + + KthValueLauncher(int64_t k) : k(k) {} + + template + inline void launch( + cuda::detail::TensorInfo values_info, + int collapse_values_dim, + cuda::detail::TensorInfo indices_info, + int collapse_indices_dim, + cuda::detail::TensorInfo self_info, + int collapse_self_dim, + int64_t num_slices, + int64_t slice_size) { + dim3 grid; + if (!getGridFromTiles(num_slices, grid)) { + AT_ERROR("slices are too many"); + } + + dim3 block(std::min( + THCRoundUp(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024)); + auto stream = at::cuda::getCurrentCUDAStream(); + gatherKthValue<<>>( + self_info, + slice_size, + k, + num_slices, + /* The actual dimension that the k-selection is running in */ + /* may have changed from collapseDims() */ + self_info.strides[collapse_self_dim], + values_info, + indices_info); + } +}; + +struct MedianLauncher { + bool ignore_nan; + + MedianLauncher(bool ignore_nan) : ignore_nan(ignore_nan) {} + + template + inline void launch( + cuda::detail::TensorInfo values_info, + int collapse_values_dim, + cuda::detail::TensorInfo indices_info, + int collapse_indices_dim, + cuda::detail::TensorInfo self_info, + int collapse_self_dim, + int64_t num_slices, + int64_t slice_size) { + dim3 grid; + if (!getGridFromTiles(num_slices, grid)) { + AT_ERROR("slices are too many"); + } + + dim3 block(std::min( + THCRoundUp(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024)); + auto stream = at::cuda::getCurrentCUDAStream(); + gatherMedian<<>>( + values_info, + indices_info, + self_info, + slice_size, + num_slices, + self_info.strides[collapse_self_dim], + ignore_nan); + } +}; + +template +void kthvalue_cuda_template( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim()); + int64_t slicesize = self.size(dim); + // FIXME: This seems bogus, I only do this because it was the old behaviour. + // The reductions are fine, as long as the axis being reduced along + // isn't of 0 elements (and the output has elements). 
+ TORCH_CHECK( + self.numel() > 0, + "cannot perform reduction function kthvalue", + " on tensor with no elements because the operation does not have an identity"); + TORCH_CHECK(k >= 1 && k <= slicesize, "selected number k out of range"); + + _reduction_with_indices_allocate_or_resize_output( + values, indices, self, dim, keepdim); + if (self.dim() == 0 && self.numel() == 1) { + values.copy_(self); + indices.zero_(); + return; + } + + TORCH_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + // Based on required index size, run the algorithm with the + // appropriate index type + if (cuda::detail::canUse32BitIndexMath(self) && + cuda::detail::canUse32BitIndexMath(values) && + cuda::detail::canUse32BitIndexMath(indices)) { + run_launcher( + values, indices, self, dim, KthValueLauncher(k)); + } else { + run_launcher( + values, indices, self, dim, KthValueLauncher(k)); + } + + if (!keepdim) { + values.squeeze_(dim); + indices.squeeze_(dim); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +std::tuple kthvalue_out_impl_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] { + kthvalue_cuda_template( + values, indices, self, k, dim, keepdim); + }); + return std::forward_as_tuple(values, indices); +} + +std::tuple median_with_indices_impl( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim, + bool ignore_nan) { + NoNamesGuard guard; + + dim = at::maybe_wrap_dim(dim, self.dim()); + Tensor in = self.dim() > 0 ? self : self.unsqueeze(0); + + int64_t size = in.size(dim); + TORCH_CHECK( + size > 0, + "median() cannot compute median for a dimension of size 0 because ", + "the operation does not have an identity"); + + checkDeviceType("median", {values, indices}, self.device().type()); + checkScalarType("median", {indices, "indices", 1}, kLong); + checkSameType("median", {values, "values", 0}, {self, "self", 2}); + + TORCH_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "median() cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + std::vector out_shape = self.sizes().vec(); + if (self.dim() > 0) { + if (keepdim) { + out_shape[dim] = 1; + } else { + out_shape.erase(out_shape.begin() + dim); + } + } + + values.resize_(out_shape); + indices.resize_(out_shape); + + // Only launch kernel for non-empty tensors + if (self.numel() > 0) { + // Ensure #dim is the same for all tensors required for reduction + Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim); + Tensor inds = keepdim && self.dim() > 0 ? 
indices : indices.unsqueeze(dim); + + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, self.scalar_type(), "median_out_impl", [&] { + if (cuda::detail::canUse32BitIndexMath(vals) && + cuda::detail::canUse32BitIndexMath(inds) && + cuda::detail::canUse32BitIndexMath(in)) { + run_launcher( + vals, inds, in, dim, MedianLauncher(ignore_nan)); + } else { + run_launcher( + vals, inds, in, dim, MedianLauncher(ignore_nan)); + } + }); + + AT_CUDA_CHECK(cudaGetLastError()); + } + + guard.reset(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + + return std::forward_as_tuple(values, indices); +} + +Tensor median_impl(const Tensor& self, bool ignore_nan) { + NoNamesGuard guard; + + int64_t size = self.numel(); + TORCH_CHECK(size > 0, "median() input tensor cannot be empty"); + + // Sort input tensor to efficiently query for median element + Tensor sorted = std::get<0>(self.flatten().sort()); + + if (!ignore_nan) { + // For torch.median return either the middle element or nan (sorted as + // largest) if there are any + int64_t k = (size - 1) / 2; + return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]); + } else { + // For torch.nanmedian return the middle element among the non-nan values + Tensor k = ((size - 1) - sorted.isnan().sum()) / 2; + return sorted[k.toType(kLong)]; + } +} + +} // namespace + +// Mark: kthvalue + +std::tuple kthvalue_out_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + auto result = [&]() { + NoNamesGuard guard; + return kthvalue_out_impl_cuda(values, indices, self, k, dim, keepdim); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; +} + +// Mark: median + +std::tuple median_out_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim) { + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/false); +} + +Tensor median_cuda(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/false); +} + +std::tuple nanmedian_out_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim) { + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/true); +} + +Tensor nanmedian_cuda(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/true); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu deleted file mode 100644 index 953239197a66..000000000000 --- a/aten/src/ATen/native/cuda/SortingKthValue.cu +++ /dev/null @@ -1,259 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include // only for THCRoundUp? 
-#include -#include -#include // AddOp - -#include -#include -#include - -namespace at { -namespace native { - -namespace { - -template -__global__ void gatherKthValue( - cuda::detail::TensorInfo input, - index_t inputSliceSize, - index_t k, - - index_t numInputSlices, - index_t inputWithinSliceStride, - - cuda::detail::TensorInfo kthValue, - cuda::detail::TensorInfo indices) { - // Indices are limited to integer fp precision, so counts can fit in - // int32, regardless of index_t - __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit - - index_t slice = getLinearBlockId(); - if (slice >= numInputSlices) { - return; - } - - // Find the start offset for our slice - index_t sliceStartIndex = - cuda::detail::IndexToOffset::get(slice, input); - index_t kthValueSliceStartIndex = - cuda::detail::IndexToOffset::get(slice, kthValue); - index_t indicesSliceStartIndex = - cuda::detail::IndexToOffset::get(slice, indices); - - scalar_t* inputSliceStart = &input.data[sliceStartIndex]; - scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex]; - int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; - - // Find the k-th highest element in our input - scalar_t kValue = static_cast(0); - radixSelect< - scalar_t, - typename TopKTypeConfig::RadixType, - index_t, - false>( - inputSliceStart, - k, - inputSliceSize, - inputWithinSliceStride, - smem, - &kValue); - - // Find the index of the k-th highest element - index_t kValueIndex = 0; - bool foundKValue = false; - - for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { - bool inRange = (i < inputSliceSize); - scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) - : static_cast(0); - bool isKValue = inRange && - ((v == kValue) || - (THCNumerics::isnan(v) && - THCNumerics::isnan(kValue))); - if (isKValue) { - kValueIndex = i; - foundKValue = true; - break; - } - } - - if (foundKValue) { - kthValueSliceStart[0] = kValue; - indicesSliceStart[0] = kValueIndex; - } -} - -struct KthValueLauncher { - int64_t k; - - KthValueLauncher(int64_t k) : k(k) {} - - template - inline void launch( - cuda::detail::TensorInfo values_info, - int collapse_values_dim, - cuda::detail::TensorInfo indices_info, - int collapse_indices_dim, - cuda::detail::TensorInfo self_info, - int collapse_self_dim, - int64_t num_slices, - int64_t slice_size) { - dim3 grid; - if (!getGridFromTiles(num_slices, grid)) { - AT_ERROR("slices are too many"); - } - - dim3 block( - std::min(THCRoundUp(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024)); - auto stream = at::cuda::getCurrentCUDAStream(); - gatherKthValue<<>>( - self_info, - slice_size, - k, - num_slices, - /* The actual dimension that the k-selection is running in */ - /* may have changed from collapseDims() */ - self_info.strides[collapse_self_dim], - values_info, - indices_info); - } -}; - -template -void kthvalue_cuda_template( - Tensor& values, - Tensor& indices, - const Tensor& self, - int64_t k, - int64_t dim_, - bool keepdim) { - int64_t dim = maybe_wrap_dim(dim_, self.dim()); - int64_t slicesize = self.size(dim); - // FIXME: This seems bogus, I only do this because it was the old behaviour. - // The reductions are fine, as long as the axis being reduced along - // isn't of 0 elements (and the output has elements). 
- TORCH_CHECK( - self.numel() > 0, - "cannot perform reduction function kthvalue", - " on tensor with no elements because the operation does not have an identity"); - TORCH_CHECK(k >= 1 && k <= slicesize, "selected number k out of range"); - - _reduction_with_indices_allocate_or_resize_output( - values, indices, self, dim, keepdim); - if (self.dim() == 0 && self.numel() == 1) { - values.copy_(self); - indices.zero_(); - return; - } - - TORCH_CHECK( - self.dim() <= MAX_TENSORINFO_DIMS, - "cannot operate on more than ", - MAX_TENSORINFO_DIMS, - " dimensions"); - - // Based on required index size, run the algorithm with the - // appropriate index type - if (cuda::detail::canUse32BitIndexMath(self) && - cuda::detail::canUse32BitIndexMath(values) && - cuda::detail::canUse32BitIndexMath(indices)) { - run_launcher( - values, indices, self, dim, KthValueLauncher(k)); - } else { - run_launcher( - values, indices, self, dim, KthValueLauncher(k)); - } - - if (!keepdim) { - values.squeeze_(dim); - indices.squeeze_(dim); - } - - AT_CUDA_CHECK(cudaGetLastError()); -} - -// this does not reduce to median with dim because we don't want to copy twice -template -Tensor median_cuda_template(const Tensor& self) { - TORCH_CHECK(self.numel() > 0, "median cannot be called with empty tensor"); - if (self.dim() == 0 && self.numel() == 1) { - return self.clone(at::MemoryFormat::Contiguous); - } - auto self_copy = self.clone(at::MemoryFormat::Contiguous).view(-1); - auto values = at::empty({1}, self.options()); - auto indices = at::empty({1}, self.options().dtype(kLong)); - TORCH_CHECK( - self.dim() <= MAX_TENSORINFO_DIMS, - "cannot operate on more than ", - MAX_TENSORINFO_DIMS, - " dimensions"); - - // Based on required index size, run the algorithm with the - // appropriate index type - if (cuda::detail::canUse32BitIndexMath(self) && - cuda::detail::canUse32BitIndexMath(values) && - cuda::detail::canUse32BitIndexMath(indices)) { - run_launcher( - values, - indices, - self_copy, - 0, - KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based - } else { - run_launcher( - values, - indices, - self_copy, - 0, - KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based - } - return values.view({}); -} - -} // namespace - -static std::tuple kthvalue_out_impl_cuda( - Tensor& values, - Tensor& indices, - const Tensor& self, - int64_t k, - int64_t dim, - bool keepdim) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] { - kthvalue_cuda_template(values, indices, self, k, dim, keepdim); - }); - return std::forward_as_tuple(values, indices); -} - -std::tuple kthvalue_out_cuda( - Tensor& values, - Tensor& indices, - const Tensor& self, - int64_t k, - int64_t dim, - bool keepdim) { - auto result = [&]() { - NoNamesGuard guard; - return kthvalue_out_impl_cuda(values, indices, self, k, dim, keepdim); - }(); - namedinference::propagate_names_for_reduction(values, self, dim, keepdim); - namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); - return result; -} - -Tensor median_cuda(const Tensor& self) { - NoNamesGuard guard; - return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "median", [&] { - return median_cuda_template(self); - }); -} - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8559eb8bdc60..f2b02580a6d3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ 
-2307,17 +2307,48 @@ - func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) +- func: median(Tensor self) -> Tensor + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU: median_cpu + CUDA: median_cuda + - func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) use_c10_dispatcher: full variants: function, method - func: median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + dispatch: + CPU: median_out_cpu + CUDA: median_out_cuda - func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method - func: median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) +- func: nanmedian(Tensor self) -> Tensor + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU: nanmedian_cpu + CUDA: nanmedian_cuda + +- func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + use_c10_dispatcher: full + variants: function, method + +- func: nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + dispatch: + CPU: nanmedian_out_cpu + CUDA: nanmedian_out_cuda + +- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + +- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + - func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) use_c10_dispatcher: full variants: function, method @@ -6088,13 +6119,6 @@ use_c10_dispatcher: full variants: method, function -- func: median(Tensor self) -> Tensor - use_c10_dispatcher: full - variants: method, function - dispatch: - CPU: median_cpu - CUDA: median_cuda - - func: quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) - func: quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False) -> Tensor diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index 96453255112d..46258e1b2961 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -16,7 +16,7 @@ at::Tensor embedding_bag_4bit_helper( const at::Tensor& weight, const at::Tensor& indices, const c10::optional& offsets_in, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, bool include_last_offset) { @@ -38,10 +38,10 @@ at::Tensor embedding_bag_4bit_helper( auto weight_contig = weight.contiguous(); uint8_t* input_data = weight_contig.data_ptr(); - // Get compressed indices for sparse op. + // Get compressed indices for pruned_weights op. 
int32_t* compressed_indices_mapping_data = nullptr; int compressed_index_size = 0; - if (sparse) { + if (pruned_weights) { compressed_index_size = compressed_indices_mapping.value().numel(); compressed_indices_mapping_data = compressed_indices_mapping.value().data_ptr(); @@ -77,7 +77,7 @@ at::Tensor embedding_bag_4bit_helper( const int index_size = indices.numel(); constexpr int prefetch_distance = 16; #ifdef USE_FBGEMM - if (!sparse) { + if (!pruned_weights) { // Generate the fbgemm kernel auto kernel_64_ = fbgemm::GenerateEmbeddingSpMDMNBit( /*bit rate=*/4, @@ -158,7 +158,7 @@ at::Tensor embedding_bag_4bit_helper( for (int i = 0; i < lengths_data[m]; ++i, ++current) { int64_t idx; - if (!sparse) { + if (!pruned_weights) { idx = indices_data[current]; TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data"); } else { @@ -201,7 +201,7 @@ at::Tensor embedding_bag_byte_helper( const at::Tensor& packed_w, const at::Tensor& indices, const c10::optional& offsets_in, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, bool include_last_offset) { TORCH_CHECK( @@ -301,14 +301,14 @@ at::Tensor embedding_bag_byte_helper( at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( const at::Tensor& indices, const c10::optional& offsets_in, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, bool include_last_offset) { return embedding_bag_byte_helper( packed_w, indices, offsets_in, - sparse, + pruned_weights, per_sample_weights_, include_last_offset); } @@ -316,7 +316,7 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( const at::Tensor& indices, const c10::optional& offsets_in, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, bool include_last_offset) { @@ -324,7 +324,7 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( packed_w, indices, offsets_in, - sparse, + pruned_weights, per_sample_weights_, compressed_indices_mapping, include_last_offset); @@ -340,7 +340,7 @@ Tensor embedding_bag_byte_rowwise_offsets( const c10::optional& offsets_in, const bool /* scale_grad_by_freq */, const int64_t /* mode */, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, bool include_last_offset) { TORCH_CHECK(weight.scalar_type() == at::kByte); @@ -349,7 +349,7 @@ Tensor embedding_bag_byte_rowwise_offsets( weight, indices, offsets_in, - sparse, + pruned_weights, per_sample_weights_, include_last_offset); } @@ -360,7 +360,7 @@ Tensor embedding_bag_4bit_rowwise_offsets( const c10::optional& offsets_in, const bool /* scale_grad_by_freq */, const int64_t /* mode */, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, bool include_last_offset) { @@ -368,7 +368,7 @@ Tensor embedding_bag_4bit_rowwise_offsets( weight, indices, offsets_in, - sparse, + pruned_weights, per_sample_weights_, compressed_indices_mapping, include_last_offset); @@ -383,18 +383,22 @@ class QEmbeddingBag final { const c10::optional& offsets, const bool /* scale_grad_by_freq */, const int64_t /* mode */, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, bool include_last_offset) { if (bit_rate == 8) { return packed_weight->embeddingbag_byte( - indices, offsets, sparse, per_sample_weights_, include_last_offset); + indices, + offsets, + pruned_weights, + 
per_sample_weights_, + include_last_offset); } else if (bit_rate == 4) { return packed_weight->embeddingbag_4bit( indices, offsets, - sparse, + pruned_weights, per_sample_weights_, compressed_indices_mapping, include_last_offset); @@ -411,13 +415,13 @@ class QEmbedding final { static at::Tensor run( const c10::intrusive_ptr& packed_weight, const Tensor& indices, - bool sparse) { + bool pruned_weights) { const auto offsets_size = indices.numel(); at::Tensor offsets = at::arange(0, offsets_size, at::kLong); at::Tensor output; if (bit_rate == 8) { return packed_weight->embeddingbag_byte( - indices, offsets, sparse, c10::nullopt, false); + indices, offsets, pruned_weights, c10::nullopt, false); } else { TORCH_INTERNAL_ASSERT( "Currently only support 8-bit embedding quantization"); diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 1ab399da88e6..c16950a64c77 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -111,11 +111,11 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool sparse=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? 
compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor")); diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 95e0e46bf79a..9cdd46a4ea21 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -344,7 +344,7 @@ class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):     def forward(self):         return self.op_func(self.input_one) -op_bench.generate_pt_tests_from_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) +op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) if __name__ == "__main__":     op_bench.benchmark_runner.main() @@ -388,10 +388,10 @@ class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): ``` #### Part 3. Register a List of Operators -To register multiple operators, we introduced the `generate_pt_tests_from_list` function which takes three parameters. First, the list of operators. Second,the configs. Third, the benchmark class.   +To register multiple operators, we introduced the `generate_pt_tests_from_op_list` function which takes three parameters. First, the list of operators. Second,the configs. Third, the benchmark class.   Here is an example: ``` -op_bench.generate_pt_tests_from_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) +op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) ``` diff --git a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py index e80de5ebb619..335be21bcd17 100644 --- a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py +++ b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py @@ -1,7 +1,7 @@ import operator_benchmark as op_bench import torch -import torch.nn.quantized.dynamic as nnqd +import torch.nn.quantized as nnq import numpy from . 
import configs @@ -11,7 +11,7 @@ class QEmbeddingBagBenchmark(op_bench.TorchBenchmarkBase): def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_last_offset, device): - self.embedding = nnqd.EmbeddingBag( + self.embedding = nnq.EmbeddingBag( num_embeddings=embeddingbags, embedding_dim=dim, mode=mode, diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8e358c9503f7..49c6ad78ee4f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -501,8 +501,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) endif() endif() - if(USE_CUDA) - list(APPEND Caffe2_GPU_SRCS + if(USE_CUDA OR USE_ROCM) + list(APPEND Caffe2_GPU_HIP_JIT_FUSERS_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp @@ -553,6 +553,10 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/type.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp ) + endif() + + if(USE_CUDA) + list(APPEND Caffe2_GPU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) if(MSVC) # Delay load nvcuda.dll so we can import torch compiled with cuda on a CPU-only machine @@ -570,13 +574,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) endif() if(USE_ROCM) - list(APPEND Caffe2_HIP_SRCS - ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp - ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp - ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp - ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp - ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp - ) + list(APPEND Caffe2_HIP_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) if(USE_NCCL) list(APPEND Caffe2_HIP_SRCS ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc index 3f70e96fffc8..7802cb0757fc 100644 --- a/caffe2/core/plan_executor.cc +++ b/caffe2/core/plan_executor.cc @@ -17,10 +17,18 @@ C10_DEFINE_bool( "If used we will handle exceptions in executor threads. " "This avoids SIGABRT but may cause process to deadlock"); +C10_DEFINE_int( + caffe2_plan_executor_exception_timeout, + 60, + "Number of seconds to wait for concurrent threads to stop on exception" + "before terminating."); + namespace caffe2 { namespace { +// ExceptionWrapper holds an exception. If exception pointers are being used, +// it'll hold the original exception pointer otherwise just the message. class ExceptionWrapper { public: ExceptionWrapper() : hasException_(false) {} @@ -39,6 +47,10 @@ class ExceptionWrapper { #endif } + const std::string& what() const { + return exceptionMsg_; + } + operator bool() { return hasException_; } @@ -51,6 +63,33 @@ class ExceptionWrapper { std::string exceptionMsg_; }; +// ExceptionWrapperTerminate terminates the program with the specified +// exception. This preserves the exception ptr and ExceptionTracer will +// correctly grab it on exit. +class ExceptionWrapperTerminate { + public: + explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew) : ew_(std::move(ew)) {} + + ~ExceptionWrapperTerminate() { + ew_.rethrowException(); + } + + private: + ExceptionWrapper ew_; +}; + +// ScopeExitGuard runs the provided function when it's destructed. 
+class ScopeExitGuard { + public: + explicit ScopeExitGuard(std::function&& f) : f_(std::move(f)) {} + ~ScopeExitGuard() { + f_(); + } + + private: + std::function f_; +}; + struct NetDefInfo { const NetDef* netDef; // in order to keep the "override existing nets" on the top-level workflow, @@ -460,9 +499,17 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { << " with " << step.substep().size() << " concurrent substeps"; std::atomic next_substep{0}; - std::mutex exception_mutex; + std::condition_variable cv; + std::mutex exception_mutex; // exception_mutex protects done and first_exception + int done{0}; ExceptionWrapper first_exception; auto worker = [&]() { + ScopeExitGuard on_exit([&] { + std::lock_guard guard(exception_mutex); + done += 1; + cv.notify_all(); + }); + auto num_substeps = compiledStep->recurringSubsteps.size(); int substep_id = next_substep++ % num_substeps; if (compiledStep->gotFailure) { @@ -492,6 +539,8 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { } }; + std::unique_lock guard(exception_mutex); + std::vector threads; auto numThreads = compiledStep->recurringSubsteps.size(); if (step.has_num_concurrent_instances()) { @@ -500,6 +549,22 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { for (size_t i = 0; i < numThreads; ++i) { threads.emplace_back(worker); } + + auto workersDone = [&] { return done == numThreads; }; + + // If we get an exception, try to wait for all threads to stop + // gracefully. + cv.wait(guard, [&] { return workersDone() || first_exception; }); + cv.wait_for( + guard, + std::chrono::seconds(FLAGS_caffe2_plan_executor_exception_timeout), + [&] { return workersDone(); }); + if (!workersDone() && first_exception) { + LOG(ERROR) << "failed to stop concurrent workers after exception: " + << first_exception.what(); + ExceptionWrapperTerminate(std::move(first_exception)); + } + for (auto& thread : threads) { thread.join(); } diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc index 86f145d72a09..bfd83968b4e2 100644 --- a/caffe2/core/plan_executor_test.cc +++ b/caffe2/core/plan_executor_test.cc @@ -18,6 +18,35 @@ static std::atomic cancelCount{0}; static std::atomic stuckRun{false}; } // namespace +class StuckBlockingOp final : public Operator { + public: + StuckBlockingOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + + bool RunOnDevice() override { + // StuckBlockingOp runs and notifies ErrorOp. + stuckRun = true; + + while (!cancelled_) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + return true; + } + + void Cancel() override { + LOG(INFO) << "cancelled StuckBlockingOp."; + cancelCount += 1; + cancelled_ = true; + } + + private: + std::atomic cancelled_{false}; +}; + +REGISTER_CPU_OPERATOR(StuckBlocking, StuckBlockingOp); +OPERATOR_SCHEMA(StuckBlocking).NumInputs(0).NumOutputs(0); + class StuckAsyncOp final : public Operator { public: StuckAsyncOp(const OperatorDef& operator_def, Workspace* ws) @@ -55,7 +84,7 @@ class ErrorOp final : public Operator { : Operator(operator_def, ws) {} bool RunOnDevice() override { - // Wait for StuckAsyncOp to run first. + // Wait for StuckAsyncOp or StuckBlockingOp to run first. 
while (!stuckRun) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } @@ -67,6 +96,29 @@ class ErrorOp final : public Operator { REGISTER_CPU_OPERATOR(Error, ErrorOp); OPERATOR_SCHEMA(Error).NumInputs(0).NumOutputs(0); +static std::atomic blockingErrorRuns{0}; +class BlockingErrorOp final : public Operator { + public: + BlockingErrorOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + + bool RunOnDevice() override { + // First n op executions should block and then start throwing errors. + if (blockingErrorRuns.fetch_sub(1) >= 1) { + LOG(INFO) << "blocking"; + while (true) { + std::this_thread::sleep_for(std::chrono::hours(10)); + } + } else { + LOG(INFO) << "throwing"; + throw TestError(); + } + } +}; + +REGISTER_CPU_OPERATOR(BlockingError, BlockingErrorOp); +OPERATOR_SCHEMA(BlockingError).NumInputs(0).NumOutputs(0); + PlanDef parallelErrorPlan() { PlanDef plan_def; @@ -100,11 +152,46 @@ PlanDef parallelErrorPlan() { return plan_def; } +PlanDef parallelErrorPlanWithCancellableStuckNet() { + // Set a plan with two nets: one stuck net with blocking operator that never + // returns; one error net with error op that throws. + PlanDef plan_def; + + auto* stuck_blocking_net = plan_def.add_network(); + stuck_blocking_net->set_name("stuck_blocking_net"); + { + auto* op = stuck_blocking_net->add_op(); + op->set_type("StuckBlocking"); + } + + auto* error_net = plan_def.add_network(); + error_net->set_name("error_net"); + { + auto* op = error_net->add_op(); + op->set_type("Error"); + } + + auto* execution_step = plan_def.add_execution_step(); + execution_step->set_concurrent_substeps(true); + { + auto* substep = execution_step->add_substep(); + substep->add_network(stuck_blocking_net->name()); + } + { + auto* substep = execution_step->add_substep(); + substep->add_network(error_net->name()); + } + + return plan_def; +} + struct HandleExecutorThreadExceptionsGuard { - HandleExecutorThreadExceptionsGuard() { + HandleExecutorThreadExceptionsGuard(int timeout = 60) { globalInit({ "caffe2", "--caffe2_handle_executor_threads_exceptions=1", + "--caffe2_plan_executor_exception_timeout=" + + caffe2::to_string(timeout), }); } @@ -139,6 +226,48 @@ TEST(PlanExecutorTest, ErrorAsyncPlan) { ASSERT_EQ(cancelCount, 1); } +// death tests not supported on mobile +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) +TEST(PlanExecutorTest, BlockingErrorPlan) { + // TSAN doesn't play nicely with death tests +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) + return; +#endif +#endif + + ASSERT_DEATH( + [] { + HandleExecutorThreadExceptionsGuard guard(/*timeout=*/1); + + PlanDef plan_def; + + std::string plan_def_template = R"DOC( + network { + name: "net" + op { + type: "BlockingError" + } + } + execution_step { + num_concurrent_instances: 2 + substep { + network: "net" + } + } + )DOC"; + + CAFFE_ENFORCE( + TextFormat::ParseFromString(plan_def_template, &plan_def)); + Workspace ws; + blockingErrorRuns = 1; + ws.RunPlan(plan_def); + FAIL() << "shouldn't have reached this point"; + }(), + "failed to stop concurrent workers after exception: test error"); +} +#endif + } // namespace caffe2 #endif diff --git a/caffe2/operators/channel_shuffle_op.cu b/caffe2/operators/channel_shuffle_op.cu index 2c5a4e5e7f9f..34f07afb6dd9 100644 --- a/caffe2/operators/channel_shuffle_op.cu +++ b/caffe2/operators/channel_shuffle_op.cu @@ -33,7 +33,7 @@ __global__ void ChannelShuffleNCHWKernel( template __global__ void -ChannelShuffleNHWCKernel(const int G, const int K, const 
float* X, float* Y) { +ChannelShuffleNHWCKernel(const int G, const int K, const T* X, T* Y) { __shared__ T sdata[kSharedSize]; const int C = G * K; const int offset = blockIdx.x * C; diff --git a/caffe2/opt/shape_info.cc b/caffe2/opt/shape_info.cc index dfcdeb0356bd..7e3ac1b15dc9 100644 --- a/caffe2/opt/shape_info.cc +++ b/caffe2/opt/shape_info.cc @@ -2,6 +2,7 @@ #include "caffe2/core/operator.h" #include "caffe2/core/tensor_int8.h" #include "caffe2/utils/string_utils.h" +#include namespace caffe2 { diff --git a/caffe2/python/hypothesis_test_util.py b/caffe2/python/hypothesis_test_util.py index 2000e269969e..404ba7f5c1a4 100644 --- a/caffe2/python/hypothesis_test_util.py +++ b/caffe2/python/hypothesis_test_util.py @@ -55,11 +55,7 @@ def is_sandcastle(): - if os.getenv('SANDCASTLE') == '1': - return True - elif os.getenv('TW_JOB_USER') == 'sandcastle': - return True - return False + return os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle' def is_travis(): diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 8117a3a63668..f5bce396054b 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -384,20 +384,14 @@ Collective functions .. autofunction:: broadcast -.. autofunction:: broadcast_object_list - .. autofunction:: all_reduce .. autofunction:: reduce .. autofunction:: all_gather -.. autofunction:: all_gather_object - .. autofunction:: gather -.. autofunction:: gather_object - .. autofunction:: scatter .. autofunction:: reduce_scatter diff --git a/docs/source/fft.rst b/docs/source/fft.rst index ab50bd271d32..ef728ea6579e 100644 --- a/docs/source/fft.rst +++ b/docs/source/fft.rst @@ -16,8 +16,8 @@ conflicts with the :func:`torch.fft` function. .. currentmodule:: torch.fft -Functions ---------- +Fast Fourier Transforms +----------------------- .. autofunction:: fft .. autofunction:: ifft @@ -29,3 +29,11 @@ Functions .. autofunction:: irfftn .. autofunction:: hfft .. autofunction:: ihfft + +Helper Functions +---------------- + +.. autofunction:: fftfreq +.. autofunction:: rfftfreq +.. autofunction:: fftshift +.. autofunction:: ifftshift diff --git a/docs/source/name_inference.rst b/docs/source/name_inference.rst index ccbb8c0c54d3..2606c82280b8 100644 --- a/docs/source/name_inference.rst +++ b/docs/source/name_inference.rst @@ -151,6 +151,7 @@ If you don't see an operation listed here, but it would help your use case, plea ":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc` ":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc` ":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.nanmedian`, :func:`torch.nanmedian`",:ref:`removes_dimensions-doc` ":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc` ":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc` ":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc` diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 8d195c04037c..d8144227fdaf 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -284,6 +284,16 @@ Vision Layers nn.UpsamplingNearest2d nn.UpsamplingBilinear2d +Shuffle Layers +---------------- + +.. 
autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + nn.ChannelShuffle + DataParallel Layers (multi-GPU, distributed) -------------------------------------------- diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index d455f76b8c45..a713fd98f75d 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -214,80 +214,278 @@ proper thread locking code to ensure the hooks are thread safe. .. _complex_autograd-doc: Autograd for Complex Numbers -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +---------------------------- + +The short version: + +- When you use PyTorch to differentiate any function :math:`f(z)` with complex domain and/or codomain, + the gradients are computed under the assumption that the function is a part of a larger real-valued + loss function :math:`g(input)=L`. The gradient computed is :math:`\frac{\partial L}{\partial z^*}` + (note the conjugation of z), which is precisely the direction of the step + you should take in gradient descent. Thus, all the existing optimizers work out of + the box with complex parameters. +- This convention matches TensorFlow's convention for complex + differentiation, but is different from JAX (which computes + :math:`\frac{\partial L}{\partial z}`). +- If you have a real-to-real function which internally uses complex + operations, the convention here doesn't matter: you will always get + the same result that you would have gotten if it had been implemented + with only real operations. + +If you are curious about the mathematical details, or want to know how +to define complex derivatives in PyTorch, read on. + +What are complex derivatives? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -**What notion of complex derivative does PyTorch use?** -******************************************************* +The mathematical definition of complex-differentiability takes the +limit definition of a derivative and generalizes it to operate on +complex numbers. For a function :math:`f: ℂ → ℂ`, we can write: -PyTorch follows `JAX's `_ -convention for autograd for Complex Numbers. + .. math:: + f'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h} + +In order for this limit to exist, not only must :math:`u` and :math:`v` (the real and imaginary parts of :math:`f`) be +real differentiable, but :math:`f` must also satisfy the Cauchy-Riemann `equations +<https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations>`_. In +other words: the limit computed for real and imaginary steps (:math:`h`) +must be equal. This is a more restrictive condition. + +The complex differentiable functions are commonly known as holomorphic +functions. They are well behaved, have all the nice properties that +you've seen from real differentiable functions, but are practically of no +use in the optimization world. For optimization problems, only real valued objective +functions are used in the research community since complex numbers are not part of any +ordered field and so having complex valued loss does not make much sense. + +It also turns out that no interesting real-valued objective functions fulfill the +Cauchy-Riemann equations. So the theory of holomorphic functions cannot be +used for optimization and most people therefore use the Wirtinger calculus. + +Wirtinger Calculus comes into the picture ... +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +So, we have this great theory of complex differentiability and +holomorphic functions, and we can’t use any of it at all, because many +of the commonly used functions are not holomorphic. What’s a poor +mathematician to do?
Well, Wirtinger observed that even if :math:`f(z)` +isn’t holomorphic, one could rewrite it as a two variable function +:math:`f(z, z^*)` which is always holomorphic. This is because the real and +imaginary components of :math:`z` can be expressed in terms of +:math:`z` and :math:`z^*` as: -Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions u and v -which compute the real and imaginary parts of the function: + .. math:: + \begin{aligned} + Re(z) &= \frac {z + z^*}{2} \\ + Im(z) &= \frac {z - z^*}{2j} + \end{aligned} + +Wirtinger calculus suggests studying :math:`f(z, z^*)` instead, which is +guaranteed to be holomorphic if :math:`f` was real differentiable (another +way to think of it is as a change of coordinate system, from :math:`f(x, y)` +to :math:`f(z, z^*)`.) This function has partial derivatives +:math:`\frac{\partial }{\partial z}` and :math:`\frac{\partial}{\partial z^{*}}`. +We can use the chain rule to establish a +relationship between these partial derivatives and the partial +derivatives w.r.t. the real and imaginary components of :math:`z`. - .. code:: + .. math:: + \begin{aligned} + \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ + &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ + \\ + \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ + &= 1j * (\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}) + \end{aligned} + +From the above equations, we get: + + .. math:: + \begin{aligned} + \frac{\partial }{\partial z} &= 1/2 * (\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}) \\ + \frac{\partial }{\partial z^*} &= 1/2 * (\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}) + \end{aligned} - def F(z): - x, y = real(z), imag(z) - return u(x, y) + v(x, y) * 1j +which is the classic definition of Wirtinger calculus that you would find on `Wikipedia <https://en.wikipedia.org/wiki/Wirtinger_derivatives>`_. -where :math:`1j` is a unit imaginary number. +There are a lot of beautiful consequences of this change. -We define the :math:`JVP` for function :math:`F` at :math:`(x, y)` applied to a tangent -vector :math:`c+dj \in C` as: +- For one, the Cauchy-Riemann equations translate into simply saying that :math:`\frac{\partial f}{\partial z^*} = 0` (that is to say, the function :math:`f` can be written + entirely in terms of :math:`z`, without making reference to :math:`z^*`). +- Another important (and somewhat counterintuitive) result, as we’ll see later, is that when we do optimization on a real-valued loss, the step we should + take while making a variable update is given by :math:`\frac{\partial Loss}{\partial z^*}` (not :math:`\frac{\partial Loss}{\partial z}`). - .. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix} +For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf -where +How is Wirtinger Calculus useful in optimization? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Researchers in audio and other fields more commonly use gradient +descent to optimize real valued loss functions with complex variables. +Typically, these people treat the real and imaginary values as separate +channels that can be updated. For a step size :math:`s/2` and loss +:math:`L`, we can write the following equations in :math:`ℝ^2`: ..
math:: - J = \begin{bmatrix} - \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\ - \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y} \end{bmatrix} \\ + \begin{aligned} + x_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} \\ + y_{n+1} &= y_n - (s/2) * \frac{\partial L}{\partial y} + \end{aligned} + +How do these equations translate into complex space :math:`ℂ`? -This is similar to the definition of the JVP for a function defined from :math:`R^2 → R^2`, and the multiplication -with :math:`[1, 1j]^T` is used to identify the result as a complex number. + .. math:: + \begin{aligned} + z_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (s/2) * \frac{\partial L}{\partial y}) \\ + &= z_n - s * 1/2 * (\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}) \\ + &= z_n - s * \frac{\partial L}{\partial z^*} + \end{aligned} + +Something very interesting has happened: Wirtinger calculus tells us +that we can simplify the complex variable update formula above to only +refer to the conjugate Wirtinger derivative +:math:`\frac{\partial L}{\partial z^*}`, giving us exactly the step we take in optimization. + +Because the conjugate Wirtinger derivative gives us exactly the correct step for a real valued loss function, PyTorch gives you this derivative +when you differentiate a function with a real valued loss. + +How does PyTorch compute the conjugate Wirtinger derivative? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Typically, our derivative formulas take in `grad_output` as an input, +representing the incoming Vector-Jacobian product that we’ve already +computed, i.e., :math:`\frac{\partial L}{\partial s^*}`, where :math:`L` +is the loss of the entire computation (producing a real loss) and +:math:`s` is the output of our function. The goal here is to compute +:math:`\frac{\partial L}{\partial z^*}`, where :math:`z` is the input of +the function. It turns out that in the case of real loss, we can +get away with *only* calculating :math:`\frac{\partial L}{\partial z^*}`, +even though the chain rule implies that we also need to +have access to :math:`\frac{\partial L}{\partial z}`. If you want +to skip this derivation, look at the last equation in this section +and then skip to the next section. + +Let’s continue working with :math:`f: ℂ → ℂ` defined as +:math:`f(z) = f(x+yj) = u(x, y) + v(x, y)j`. As discussed above, +autograd’s gradient convention is centered around optimization for real +valued loss functions, so let’s assume :math:`f` is a part of a larger +real valued loss function :math:`g`. Using the chain rule, we can write: -We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c+dj \in C` as: + .. math:: + \frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*} + :label: [1] - .. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix} +Now, using the Wirtinger derivative definition, we can write: -In PyTorch, the `VJP` is mostly what we care about, as it is the computation performed when we do backward -mode automatic differentiation. Notice that d and :math:`1j` are negated in the formula above. Please look at -the `JAX docs `_ -to get explanation for the negative signs in the formula. + ..
math:: + \begin{aligned} + \frac{\partial L}{\partial s} = 1/2 * (\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j) \\ + \frac{\partial L}{\partial s^*} = 1/2 * (\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j) + \end{aligned} -**What happens if I call backward() on a complex scalar?** -******************************************************************************* +It should be noted here that since :math:`u` and :math:`v` are real +functions, and :math:`L` is real by our assumption that :math:`f` is a +part of a real valued function, we have: -The gradient for a complex function is computed assuming the input function is a holomorphic function. -This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom -(as in the `2x2` Jacobian matrix above), so we can’t hope to represent all of them with in a complex number. -However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the -Cauchy-Riemann equations that ensure that `2x2` Jacobians have the special form of a scale-and-rotate -matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can -obtain that gradient using backward which is just a call to `vjp` with covector `1.0`. + .. math:: + (\frac{\partial L}{\partial s})^* = \frac{\partial L}{\partial s^*} + :label: [2] -The net effect of this assumption is that the partial derivatives of the imaginary part of the function -(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar -(e.g., this is equivalent to dropping the imaginary part of the loss before performing a backwards). +i.e., :math:`\frac{\partial L}{\partial s}` equals to :math:`grad\_output^*`. -For any other desired behavior, you can specify the covector `grad_output` in :func:`torch.autograd.backward` call accordingly. +Solving the above equations for :math:`\frac{\partial L}{\partial u}` and :math:`\frac{\partial L}{\partial v}`, we get: -**How are the JVP and VJP defined for cross-domain functions?** -*************************************************************** + .. math:: + \begin{aligned} + \frac{\partial L}{\partial u} = 1/2 * (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) \\ + \frac{\partial L}{\partial v} = -1/2j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) + \end{aligned} + :label: [3] + +Substituting :eq:`[3]` in :eq:`[1]`, we get: + + .. math:: + \begin{aligned} + \frac{\partial L}{\partial z^*} &= 1/2 * (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) * \frac{\partial u}{\partial z^*} - 1/2j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) * \frac{\partial v}{\partial z^*} \\ + &= \frac{\partial L}{\partial s} * 1/2 * (\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j) + \frac{\partial L}{\partial s^*} * 1/2 * (\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j) \\ + &= \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s} * \frac{\partial (u + vj)^*}{\partial z^*} \\ + &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ + \end{aligned} + +Using :eq:`[2]`, we get: + + .. 
math:: + \begin{aligned} + \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s^*})^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * (\frac{\partial s}{\partial z})^* \\ + &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * {(\frac{\partial s}{\partial z})}^* } \\ + \end{aligned} + :label: [4] + +This last equation is the important one for writing your own gradients, +as it decomposes our derivative formula into a simpler one that is easy +to compute by hand. + +How can I write my own derivative formula for a complex function? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The above boxed equation gives us the general formula for all +derivatives on complex functions. However, we still need to +compute :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}`. +There are two ways you could do this: + + - The first way is to just use the definition of Wirtinger derivatives directly and calculate :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}` by + using :math:`\frac{\partial s}{\partial x}` and :math:`\frac{\partial s}{\partial y}` + (which you can compute in the normal way). + - The second way is to use the change of variables trick and rewrite :math:`f(z)` as a two variable function :math:`f(z, z^*)`, and compute + the conjugate Wirtinger derivatives by treating :math:`z` and :math:`z^*` as independent variables. This is often easier; for example, if the function in question is holomorphic, only :math:`z` will be used (and :math:`\frac{\partial s}{\partial z^*}` will be zero). + +Let's consider the function :math:`f(z = x + yj) = c * z = c * (x+yj)` as an example, where :math:`c \in ℝ`. + +Using the first way to compute the Wirtinger derivatives, we have. + +.. math:: + \begin{aligned} + \frac{\partial s}{\partial z} &= 1/2 * (\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j) \\ + &= 1/2 * (c - (c * 1j) * 1j) \\ + &= c \\ + \\ + \\ + \frac{\partial s}{\partial z^*} &= 1/2 * (\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j) \\ + &= 1/2 * (c + (c * 1j) * 1j) \\ + &= 0 \\ + \end{aligned} + +Using :eq:`[4]`, and `grad\_output = 1.0` (which is the default grad output value used when :func:`backward` is called on a scalar output in PyTorch), we get: + + .. math:: + \frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c + +Using the second way to compute Wirtinger derivatives, we directly get: + + .. math:: + \begin{aligned} + \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ + &= c \\ + \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ + &= 0 + \end{aligned} -Based on formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be an identity), -we use the formula given below for cross-domain functions. +And using :eq:`[4]` again, we get :math:`\frac{\partial L}{\partial z^*} = c`. As you can see, the second way involves lesser calculations, and comes +in more handy for faster calculations. -The :math:`JVP` and :math:`VJP` for a :math:`f1: ℂ → ℝ^2` are defined as: +What about cross-domain functions? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - .. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix} +Some functions map from complex inputs to real outputs, or vice versa. +These functions form a special case of :eq:`[4]`, which we can derive using the +chain rule: - .. 
math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix} + - For :math:`f: ℂ → ℝ`, we get: -The :math:`JVP` and :math:`VJP` for a :math:`f1: ℝ^2 → ℂ` are defined as: + .. math:: + \frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}} - .. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix} \\ \\ + - For :math:`f: ℝ → ℂ`, we get: - .. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J + .. math:: + \frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}}) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index ef6d60599d7f..a4d2ac805a8f 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -436,6 +436,7 @@ view of a storage and defines numeric operations on it. .. automethod:: maximum .. automethod:: mean .. automethod:: median + .. automethod:: nanmedian .. automethod:: min .. automethod:: minimum .. automethod:: mm diff --git a/docs/source/torch.rst b/docs/source/torch.rst index bb267810cf41..1fc46780e6ac 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -355,6 +355,7 @@ Reduction Ops logsumexp mean median + nanmedian mode norm nansum diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index a2f843d78f72..46d32ed4a430 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -111,6 +111,11 @@ ("aten::_amp_non_finite_check_and_unscale_", datetime.date(9999, 1, 1)), ("aten::choose_qparams_optimized", datetime.date(2020, 10, 5)), ("aten::smooth_l1_loss_backward", datetime.date(2020, 10, 15)), + ("quantized::embedding_bag_byte", datetime.date(2020, 10, 15)), + ("quantized::embedding_bag_4bit", datetime.date(2020, 10, 15)), + ("quantized::embedding_byte", datetime.date(2020, 10, 15)), + ("quantized::embedding_bag_byte_rowwise_offsets", datetime.date(2020, 10, 15)), + ("quantized::embedding_bag_4bit_rowwise_offsets", datetime.date(2020, 10, 15)), ] diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 2b2cec70358d..36cbebd9d0f1 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -522,5 +522,232 @@ void testKernelSumMultipleAxes() { } } +// This test and the following ones testing Softmax only tests with dim set +// to one of the valid input dimensions. It does not test with dim=None +// because that is supposed to be deprecated. 
+void testKernelSoftmax2D() { + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Tensor = aten::softmax(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i0 = 0; i0 < 5 + # CHECK-NEXT: for (int i1 = 0; i1 < 3 + # CHECK-NEXT: input1 + # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i0_2 = 0; i0_2 < 5 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: aten_softmax_exp + # CHECK: for (int i${other_dim}_3 = 0; i${other_dim}_3 < ${other_dim_size} + # CHECK: for (int i${softmax_dim}_3 = 0; i${softmax_dim}_3 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_4 = 0; i0_4 < 5 + # CHECK-NEXT: for (int i1_4 = 0; i1_4 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + auto other_dim = (softmax_dim + 1) % a.dim(); + + KernelScope kernel_scope; + TemplateEnv env; + env.d("dim", softmax_dim); + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + Stmt* s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + TemplateEnv ver_env; + ver_env.d("other_dim", other_dim); + ver_env.d("other_dim_size", a.sizes()[other_dim]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + auto ref = a.softmax(softmax_dim); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } +} + +void testKernelSoftmax3D() { + const auto graph_template = R"IR( + graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Tensor = aten::softmax(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i0 = 0; i0 < 3 + # CHECK-NEXT: for (int i1 = 0; i1 < 4 + # CHECK-NEXT: for (int i2 = 0; i2 < 5 + # CHECK-NEXT: input1 + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i0_2 = 0; i0_2 < 3 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 + # CHECK-NEXT: aten_softmax_exp + # CHECK: for (int i${dim1}_3 = 0; i${dim1}_3 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_3 = 0; i${dim2}_3 < ${dim2_size} + # CHECK: for (int i${softmax_dim}_3 = 0; i${softmax_dim}_3 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_4 = 0; i0_4 < 3 + # CHECK-NEXT: for (int i1_4 = 0; i1_4 < 4 + # CHECK-NEXT: for (int i2_4 = 0; i2_4 < 5 + # CHECK-NEXT: aten_softmax)IR"; + + for (int 
softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (int i = 0; i < a.dim(); ++i) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + + KernelScope kernel_scope; + TemplateEnv env; + env.d("dim", softmax_dim); + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + Stmt* s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + + auto ref = a.softmax(softmax_dim); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } +} + +void testKernelSoftmax4D() { + const auto graph_template = R"IR( + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Tensor = aten::softmax(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i0 = 0; i0 < 2 + # CHECK-NEXT: for (int i1 = 0; i1 < 3 + # CHECK-NEXT: for (int i2 = 0; i2 < 2 + # CHECK-NEXT: for (int i3 = 0; i3 < 3 + # CHECK-NEXT: input1 + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i0_2 = 0; i0_2 < 2 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 + # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 + # CHECK-NEXT: aten_softmax_exp + # CHECK: for (int i${dim1}_3 = 0; i${dim1}_3 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_3 = 0; i${dim2}_3 < ${dim2_size} + # CHECK-NEXT: for (int i${dim3}_3 = 0; i${dim3}_3 < ${dim3_size} + # CHECK: for (int i${softmax_dim}_3 = 0; i${softmax_dim}_3 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_4 = 0; i0_4 < 2 + # CHECK-NEXT: for (int i1_4 = 0; i1_4 < 3 + # CHECK-NEXT: for (int i2_4 = 0; i2_4 < 2 + # CHECK-NEXT: for (int i3_4 = 0; i3_4 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (int i = 0; i < a.dim(); ++i) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + + KernelScope kernel_scope; + TemplateEnv env; + env.d("dim", softmax_dim); + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + Stmt* s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + 
ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("dim3", other_dims[2]); + ver_env.d("dim3_size", a.sizes()[other_dims[2]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + auto ref = a.softmax(softmax_dim); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 60c8fb1d62c0..8402f6519cfb 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -2922,5 +2922,289 @@ void testDetectInlineRankMismatch() { "Placeholder indexed access is inconsistent with its rank"); } +void testCacheReadsSimple() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 3); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* j_loop = l.getLoopStmtsFor(B)[1]; + l.cacheAccesses(A->buf(), "A_local", j_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + + // just this once: verify the whole thing. 
+ const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A, int, {64, 64}); +#CHECK: for (int i +#CHECK: for (int j +#CHECK: A[ +#CHECK: } +#CHECK: } +#CHECK: for (int i_1 +#CHECK: Allocate(A_local, int, {1, 10}); +#CHECK: for (int j_1 +#CHECK: A_local[j_1] = A[ +#CHECK: } +#CHECK: for (int j_2 +#CHECK: B[10 * i_1 + j_2] = A_local[j_2]; +#CHECK: } +#CHECK: Free(A_local); +#CHECK: } +#CHECK: for (int i_2 +#CHECK: for (int j_3 +#CHECK: C[ +#CHECK: } +#CHECK: } +#CHECK: Free(A); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 3); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +void testCacheReadsOuter() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 40) + A->call(i + 31, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* i_loop = l.getLoopStmtsFor(B)[0]; + l.cacheAccesses(A->buf(), "A_local", i_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {21, 11}); +#CHECK: A_local[j_1 + 11 * i_1] = +#CHECK: B[10 * i_2 + j_2] = (A_local[(j_2 + 11 * i_2) + 12]) + (A_local[j_2 + 11 * i_2]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +void testCacheReadsInternal() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 40) + A->call(i + 31, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* j_loop = l.getLoopStmtsFor(B)[1]; + l.cacheAccesses(A->buf(), "A_local", j_loop); + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {2, 11}); +#CHECK: A_local[j_1 + 11 * i_2] = +#CHECK: B[10 * i_1 + j_2] = (A_local[j_2]) + (A_local[j_2 + 12]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + 
std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +void testCacheReadsInner() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + // note im changing the offset of the first arg of the first call to A. + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 34, j + 40) + A->call(i + 30, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* body = l.getLoopBodyFor(B); + l.cacheAccesses(A->buf(), "A_local", body); + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {5, 2}); +#CHECK: A_local[2 * i_2 + j_2] = +#CHECK: B[10 * i_1 + j_1] = (A_local[8]) + (A_local[1]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +void testCacheWritesSimple() { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 40) + A->call(i + 31, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* a_loop = l.getLoopStmtsFor(A)[1]; + l.cacheAccesses(A->buf(), "A_local", a_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {1, 64}); +#CHECK: for (int j = 0; j < 64 +#CHECK: A_local[j] = i * j; +#CHECK: for (int j_1 = 0; j_1 < 64 +#CHECK: A[64 * i + j_1] = A_local[ +#CHECK: Free(A_local); +#CHECK-NOT: A_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + 
assertAllEqual(c_data, c_ref); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 8436388f0d6b..0555c310cc5b 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -3,16 +3,18 @@ #include #include #include -#include "test/cpp/tensorexpr/test_base.h" -#include "test/cpp/tensorexpr/padded_buffer.h" -#include "torch/csrc/jit/tensorexpr/analysis.h" -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir.h" -#include "torch/csrc/jit/tensorexpr/ir_printer.h" -#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" -#include "torch/csrc/jit/tensorexpr/loopnest.h" -#include "torch/csrc/jit/tensorexpr/tensor.h" +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace torch { namespace jit { @@ -1451,5 +1453,397 @@ void testReduceInlineReducerInternal() { ASSERT_GT(oss1.str().size(), oss2.str().size()); } +void testReductionCacheAccessesOuter() { + KernelScope kernel_scope; + + int L = 4; + int N = 3; + int M = 2; + + Placeholder a(BufHandle("a", {L, N, M}, kFloat)); + Placeholder b(BufHandle("b", {L, N, M}, kFloat)); + + Tensor* c = Compute( + "scale", + {{L, "l2"}, {N, "n1"}, {M, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); + + Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + Stmt* d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(d->buf(), "d_local", d_loop); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local, float, {1}); +#CHECK: sum[l1] = 0 +#CHECK: d_local[0] = 0 +#CHECK: for (int n1 +#CHECK: for (int m1 +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: } +#CHECK: sum[l1] = (sum[l1]) + (d_local[0]) +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +void testReductionCacheAccessesInner() { + KernelScope kernel_scope; + + int L = 4; + int N = 3; + int M = 2; + + Placeholder a(BufHandle("a", {L, N, M}, kFloat)); + Placeholder b(BufHandle("b", {L, N, M}, kFloat)); + + Tensor* c = Compute( + "scale", + {{L, "l2"}, {N, "n1"}, {M, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); + + Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + Stmt* d_loop = l.getLoopStmtsFor(d)[2]; + l.cacheAccesses(d->buf(), "d_local", d_loop); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1] = 0 +#CHECK: for (int n1 +#CHECK: Allocate(d_local, float, {1}); +#CHECK: d_local[0] = 0 +#CHECK: for (int m1 +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: sum[l1] = (sum[l1]) + (d_local[0]) +#CHECK: Free(d_local); +#CHECK: } +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +void 
testReductionCacheBodyAccess() { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + Stmt* d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(c->buf(), "scale_local", d_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(scale_local, float, {1, 32, 12}); +#CHECK: for (int j = 0; j < 32; j++) { +#CHECK: for (int k = 0; k < 12; k++) { +#CHECK: scale_local[k + 12 * j] = scale[(k + 384 * l1) + 12 * j]; +#CHECK: sum[l1] = (sum[l1]) + (scale_local[12 * n1_1 + m1_1]); +#CHECK: Free(scale_local); +#CHECK: scale_1[l] = (b[l]) * (sum[l]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +void testReductionCacheConsumerAccess() { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + For* outer; + For* inner; + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + + Stmt* e_loop = l.getLoopStmtsFor(e)[1]; + l.cacheAccesses(d->buf(), "sum_local", e_loop); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1] = (sum[l1]) + (scale[ +#CHECK: Allocate(sum_local, float, {4}); +#CHECK: for (int i = 0; i < 4 +#CHECK: sum_local[i] = sum[i + 4 * l_outer]; +#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +void testReductionSplitCacheConsumerAccess() { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + For* outer; + For* inner; + + // Split outer reduction axis. + l.splitWithMask(l.getLoopStmtsFor(d)[0], 4, &outer, &inner); + + // Split reduction consumer. + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + + l.cacheAccesses(d->buf(), "sum_local", inner); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + // reduction changes but cache does not. 
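The ReductionCache* tests above all run the same three-stage pipeline — `scale = b * a`, `sum` reduces over the inner two axes, and a final `scale` consumes the reduction — and only differ in where `cacheAccesses` places the temporary. A hedged NumPy sketch of the pipeline with the accumulator staged in a local scalar, the rewrite that the `d_local`/`sum_local` IR checks describe:

```python
# Illustration only (not part of the patch): the scale -> sum -> scale pipeline
# from the ReductionCache* tests, with the reduction accumulator staged in a
# local temporary the way cacheAccesses rewrites it (d_local / sum_local).
import numpy as np

L, N, M = 4, 3, 2
a = np.random.rand(L, N, M).astype(np.float32)
b = np.random.rand(L, N, M).astype(np.float32)

scale = b * a                               # Tensor* c = Compute("scale", ...)
d = np.zeros(L, dtype=np.float32)           # Tensor* d = Reduce("sum", ...)
for l in range(L):
    d_local = np.float32(0)                 # Allocate(d_local, float, {1}); d_local[0] = 0
    for n in range(N):
        for m in range(M):
            d_local += scale[l, n, m]       # d_local[0] = (d_local[0]) + (scale[...])
    d[l] += d_local                         # sum[l1] = (sum[l1]) + (d_local[0])

np.testing.assert_allclose(d, scale.sum(axis=(1, 2)), rtol=1e-5)
# The consumer e[l] = b.load(0, 0, l) * d[l] addresses b through its flat layout,
# i.e. flat index l for the first L elements:
e = b.reshape(-1)[:L] * d
```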
+ std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]); +#CHECK: Allocate(sum_local, float, {4}); +#CHECK: for (int i = 0; i < 4 +#CHECK: sum_local[i] = sum[i + 4 * l_outer]; +#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +void testReductionReorderCacheConsumerAccess() { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + For* outer; + For* inner; + + // reorder outer reduction axes. + auto loops = l.getLoopStmtsFor(d); + l.reorderAxis(loops[0], loops[1]); + + // Split reduction consumer. + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + + l.cacheAccesses(d->buf(), "sum_local", inner); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + // neither reduction body not cache changes. + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]); +#CHECK: Allocate(sum_local, float, {4}); +#CHECK: for (int i = 0; i < 4 +#CHECK: sum_local[i] = sum[i + 4 * l_outer]; +#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +void testReductionRfactorCacheTempOuter() { + KernelScope kernel_scope; + + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + Placeholder b(BufHandle("B", {m, n, k}, kFloat)); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); + LoopNest loop({c}); + auto reduces = NodeFinder::find(loop.root_stmt()); + loop.rfactor(reduces[0], reduces[0]->reduce_args()[1]); + + reduces = NodeFinder::find(loop.root_stmt()); + std::vector loops = NodeFinder::find(loop.root_stmt()); + loop.cacheAccesses(reduces[0]->accumulator(), "tmp2", loops[2]); + loop.prepareForCodegen(); + Stmt* s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(tmp_buf, float, {n}); +#CHECK: for (int a = 0; a < m +#CHECK: Allocate(tmp2, float, {n}); +#CHECK: for (int i = 0; i < n +#CHECK: tmp2[i] = 0 +#CHECK: } +#CHECK: for (int b = 0; b < n +#CHECK: for (int c +#CHECK: tmp2[b] = (tmp2[b]) + (B[ +#CHECK: } +#CHECK: } +#CHECK: for (int i = 0; i < n +#CHECK: tmp_buf[i] = (tmp_buf[i]) + (tmp2[i]); +#CHECK: } +#CHECK: Free(tmp2); +#CHECK-NOT: tmp2 + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +void 
testReductionRfactorCacheTempInner() { + KernelScope kernel_scope; + + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + Placeholder b(BufHandle("B", {m, n, k}, kFloat)); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); + LoopNest loop({c}); + auto reduces = NodeFinder::find(loop.root_stmt()); + loop.rfactor(reduces[0], reduces[0]->reduce_args()[1]); + + reduces = NodeFinder::find(loop.root_stmt()); + std::vector loops = NodeFinder::find(loop.root_stmt()); + loop.cacheAccesses(reduces[0]->accumulator(), "tmp2", loops[3]); + loop.prepareForCodegen(); + Stmt* s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(tmp_buf, float, {n}); +#CHECK: for (int a = 0; a < m +#CHECK: for (int b = 0; b < n +#CHECK: Allocate(tmp2, float, {1}); +#CHECK: tmp2[0] = 0 +#CHECK: for (int c +#CHECK: tmp2[0] = (tmp2[0]) + (B[ +#CHECK: } +#CHECK: tmp_buf[b] = (tmp_buf[b]) + (tmp2[0]); +#CHECK: Free(tmp2); +#CHECK-NOT: tmp2 + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index c32183aaa042..d2cd1efc1a3d 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -116,6 +116,14 @@ namespace jit { _(ReduceInlineReduction) \ _(ReduceInlineConsumer) \ _(ReduceInlineReducerInternal) \ + _(ReductionCacheAccessesOuter) \ + _(ReductionCacheAccessesInner) \ + _(ReductionCacheBodyAccess) \ + _(ReductionCacheConsumerAccess) \ + _(ReductionSplitCacheConsumerAccess) \ + _(ReductionReorderCacheConsumerAccess) \ + _(ReductionRfactorCacheTempOuter) \ + _(ReductionRfactorCacheTempInner) \ _(TypeTest01) \ _(TypePropagation) \ _(Cond01) \ @@ -307,7 +315,6 @@ namespace jit { _(LoopNestComputeAt_1) \ _(LoopNestComputeAt_2) \ _(LoopNestComputeAt_3) \ - _(LoopNestComputeAt_4) \ _(LoopNestReorderAxis1) \ _(LoopNestReorderPartialAxes) \ _(LoopNestReorderInternalAxis) \ @@ -334,6 +341,11 @@ namespace jit { _(NormalizeOnNestedInnerLoop) \ _(NormalizeAndSplitWithTail) \ _(DetectInlineRankMismatch) \ + _(CacheReadsSimple) \ + _(CacheReadsOuter) \ + _(CacheReadsInternal) \ + _(CacheReadsInner) \ + _(CacheWritesSimple) \ _(Kernel_1) \ _(Kernel_2) \ _(Kernel_3) \ @@ -341,6 +353,9 @@ namespace jit { _(KernelSumAllAxes) \ _(KernelSumOneAxis) \ _(KernelSumMultipleAxes) \ + _(KernelSoftmax2D) \ + _(KernelSoftmax3D) \ + _(KernelSoftmax4D) \ _(FuserPass_1) \ _(FuserPass_2) \ _(FuserPass_3) \ diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py index 9d0c19bef7b3..56fbb5ba3f88 100644 --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -33,7 +33,7 @@ create_device from torch.testing._internal.common_utils import TestCase, load_tests, run_tests, \ - retry_on_connect_failures, ADDRESS_IN_USE, CONNECT_TIMEOUT, TEST_WITH_TSAN + retry_on_connect_failures, ADDRESS_IN_USE, CONNECT_TIMEOUT, TEST_WITH_TSAN, slowTest # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. 
This line silences flake warnings @@ -296,6 +296,8 @@ def _test_numkeys_delkeys(self, fs): self.assertEqual(b"value1", fs.get("key1")) self.assertEqual(b"value2", fs.get("key4")) + # https://github.com/pytorch/pytorch/issues/46064 <- takes 5+ min to finish + @slowTest def test_numkeys_delkeys(self): self._test_numkeys_delkeys(self._create_store()) @@ -1619,6 +1621,10 @@ def test_init_no_gpus(self): c10d.ProcessGroupNCCL(store, self.rank, self.world_size) +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) class ProcessGroupNCCLTest(TestCase): MAIN_PROCESS_RANK = 0 @@ -3828,7 +3834,6 @@ def test_multi_limit_multi_dtype(self): self.assertEqual([[0], [1], [2, 4], [3, 5]], result) -@skip_if_rocm @unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") class NcclErrorHandlingTest(MultiProcessTestCase): def setUp(self): @@ -3993,7 +3998,7 @@ def _wait_for_comm_abort(self, process_group): return else: raise e - time.sleep(0.1) + time.sleep(1) @requires_nccl() @skip_if_lt_x_gpu(3) diff --git a/test/print_test_stats.py b/test/print_test_stats.py index d1ccc3d36cc1..a1d2dc0b6a67 100755 --- a/test/print_test_stats.py +++ b/test/print_test_stats.py @@ -122,7 +122,6 @@ def send_report(reports): ), }, ) - print("Scribe report status: {}".format(r.text)) r.raise_for_status() def positive_integer(value): @@ -131,6 +130,12 @@ def positive_integer(value): raise argparse.ArgumentTypeError(f"{value} is not a natural number") return parsed +def positive_float(value): + parsed = float(value) + if parsed <= 0.0: + raise argparse.ArgumentTypeError(f"{value} is not a positive rational number") + return parsed + if __name__ == '__main__': import argparse import sys @@ -145,6 +150,13 @@ def positive_integer(value): metavar="N", help="how many longest tests to show for each class", ) + parser.add_argument( + "--class-print-threshold", + type=positive_float, + default=1.0, + metavar="N", + help="Minimal total time to warrant class report", + ) parser.add_argument( "--longest-of-run", type=positive_integer, @@ -169,7 +181,8 @@ def positive_integer(value): total_time = 0 for name in sorted(reports.keys()): test_suite = reports[name] - test_suite.print_report(args.longest_of_class) + if test_suite.total_time >= args.class_print_threshold: + test_suite.print_report(args.longest_of_class) total_time += test_suite.total_time longest_tests.extend(test_suite.test_cases) longest_tests = sorted(longest_tests, key=lambda x: x.time)[-args.longest_of_run:] diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 835dc6bf3083..ee194b9549e6 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -6,11 +6,6 @@ import torch.nn.intrinsic.quantized as nniq import torch.multiprocessing as mp -# symbolic trace -from torch.fx import symbolic_trace - -from torch.fx.symbolic_trace import Tracer - # graph mode quantization based on fx from torch.quantization import ( QuantType, @@ -175,10 +170,9 @@ def forward(self, x): return F.linear(x, self.weight) m = M(torch.rand(1, 1)).eval() - original = symbolic_trace(m) qconfig = default_dynamic_qconfig qconfig_dict = {'': qconfig} - prepared = prepare_fx(original, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict) quantized = convert_fx(prepared, debug=True) qparams = (quantized._scale_0, quantized._zero_point_0) weight_obs = qconfig.weight() @@ -224,7 +218,6 @@ def forward(self, x): if 
weight_prepack_node: node_occurrence[weight_prepack_node] = 0 m = ModuleClass(*module_constructor_inputs).eval() - m = symbolic_trace(m) qconfig_dict = {"": float16_dynamic_qconfig} m = prepare_fx(m, qconfig_dict) m = convert_fx(m, debug=debug) @@ -259,9 +252,6 @@ def forward(self, x): device = torch.device('cuda:0') model.to(device) - # symbolically trace - model = symbolic_trace(model) - # QAT prepare model = prepare_qat_fx(model, qconfig_dict) @@ -287,7 +277,6 @@ def forward(self, x): return self.conv(x) model = M().eval() - model = symbolic_trace(model) qconfig_dict = {'': default_qconfig} prepared = prepare_fx( model, qconfig_dict, inplace=False) @@ -316,7 +305,7 @@ def forward(self, x): return {"output": self.conv(x["input"])} dict_input = {"input": torch.randn(1, 1, 1, 1)} - m = symbolic_trace(M()).eval() + m = M().eval() qconfig_dict = {"": default_qconfig} m = prepare_fx(m, qconfig_dict) m(dict_input) @@ -332,12 +321,6 @@ def __init__(self): def forward(self, x): return self.conv(x) - class CustomTracer(Tracer): - def is_leaf_module(self, m, module_qualified_name): - return (m.__module__.startswith('torch.nn') and - not isinstance(m, torch.nn.Sequential)) or \ - isinstance(m, StandaloneModule) - class M(torch.nn.Module): def __init__(self): super().__init__() @@ -362,17 +345,16 @@ def forward(self, x): data = torch.randn(1, 1, 1, 1) # instantiate M and RefM and align the parameters - original_m = M() - original_ref_m = RefM() + original_m = M().eval() + original_ref_m = RefM().eval() original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) - m = torch.fx.GraphModule(original_m, CustomTracer().trace(original_m)).eval() qconfig_dict = {'': default_qconfig, 'standalone_module_name': ['standalone']} # check prepared model - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(original_m, qconfig_dict) # calibration m(data) # input and output of first conv, observer for standalone module @@ -406,8 +388,7 @@ def forward(self, x): res = m(data) # quantize the reference model - ref_m = symbolic_trace(original_ref_m).eval() - ref_m = prepare_fx(ref_m, qconfig_dict) + ref_m = prepare_fx(original_ref_m, qconfig_dict) ref_m(data) ref_m = convert_fx(ref_m) ref_res = ref_m(data) @@ -427,7 +408,6 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"": default_qconfig, "module_name": [("conv2", None)]} m = prepare_fx(m, qconfig_dict) @@ -457,7 +437,6 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) @@ -482,7 +461,6 @@ def forward(self, x, y): return x + y m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"object_type": [(operator.add, default_qconfig)]} m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) @@ -510,7 +488,6 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) @@ -547,7 +524,6 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) global_qconfig = default_qconfig object_type_qconfig = default_dynamic_qconfig module_name_regex_qconfig = 
float16_dynamic_qconfig @@ -574,7 +550,6 @@ def forward(self, x): return self.avg_pool(x) m = M().eval() - m = symbolic_trace(m) qconfig_dict = {'': default_qconfig} m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) @@ -587,13 +562,9 @@ def forward(self, x): @skipIfNoFBGEMM def test_qat_and_script(self): - - model = LinearModelWithSubmodule() + model = LinearModelWithSubmodule().train() qengine = torch.backends.quantized.engine qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)} - - # symbolically trace - model = symbolic_trace(model) model = prepare_qat_fx(model, qconfig_dict) # ensure scripting works @@ -629,8 +600,6 @@ def test_save_observer_state_dict(self): orig = LinearModelWithSubmodule().eval() model = orig qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')} - # symbolically trace - model = symbolic_trace(model) model = prepare_fx(model, qconfig_dict) # run it through input @@ -647,7 +616,6 @@ def test_save_observer_state_dict(self): # Load the stats into new model model_2 = orig - model_2 = symbolic_trace(model_2) model_2 = prepare_fx(model_2, qconfig_dict) loaded_dict = torch.load(b) @@ -659,6 +627,7 @@ def test_save_observer_state_dict(self): self.assertEqual(quant(x), quant_2(x)) @skipIfNoFBGEMM + @unittest.skip("Fix in next PR, will need to change API") def test_custom_module_class(self): class CustomModule(torch.nn.Module): def __init__(self): @@ -739,8 +708,8 @@ def forward(self, x): data = torch.randn(1, 1, 1, 1) # instantiate M and RefM and align the parameters - original_m = M() - original_ref_m = RefM() + original_m = M().eval() + original_ref_m = RefM().eval() original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach()) @@ -762,7 +731,7 @@ def is_leaf_module(self, m, module_qualified_name): register_observed_custom_module_mapping(CustomModule, ObservedCustomModule) register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule) - m = torch.fx.GraphModule(original_m, CustomTracer().trace(original_m)).eval() + m = torch.fx.GraphModule(original_m, CustomTracer().trace(original_m)) qconfig_dict = {'': default_qconfig} # check prepared model m = prepare_fx(m, qconfig_dict) @@ -785,8 +754,7 @@ def is_leaf_module(self, m, module_qualified_name): res = m(data) # quantize the reference model - ref_m = symbolic_trace(original_ref_m).eval() - ref_m = prepare_fx(ref_m, qconfig_dict) + ref_m = prepare_fx(original_ref_m, qconfig_dict) ref_m(data) ref_m = convert_fx(ref_m) ref_res = ref_m(data) @@ -1351,10 +1319,9 @@ def forward(self, x): # This model is not executable since we just put all ops # in the same forward m = M().eval() - original = symbolic_trace(m) # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(original, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict) # not runnable quantized = convert_fx(prepared) @@ -1440,10 +1407,9 @@ def forward(self, x): # This model is not executable since we just put all ops # in the same forward m = M().eval() - original = symbolic_trace(m) # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(original, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict) # not runnable quantized = convert_fx(prepared) @@ -1492,12 +1458,11 @@ def _test_model_impl( qconfig = default_qconfig if mode 
== 'static' else default_qat_qconfig qconfig_dict = {'': qconfig} - graph_module = symbolic_trace(model) # print('graph module:', graph_module.src) - script = torch.jit.script(graph_module) + script = torch.jit.script(model) # make sure graph module and script module are both runanble - original_out = graph_module(input_value) + original_out = model(input_value) is_not_tuple_out = not isinstance(original_out, tuple) script_out = script(input_value) self.assertEqual( @@ -1508,7 +1473,7 @@ def _test_model_impl( if mode != 'static': model.train() - prepared = prepare_fx(graph_module, qconfig_dict) + prepared = prepare_fx(model, qconfig_dict) if mode == 'ddp': mp.spawn(run_ddp, diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py index a0fad9b80e89..44bbf116b4dc 100644 --- a/test/quantization/test_quantize_jit.py +++ b/test/quantization/test_quantize_jit.py @@ -3039,14 +3039,14 @@ def __init__(self, weights): self.embedding1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, - sparse=False, + sparse=True, _weight=weights, mode='sum') self.embedding2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, - sparse=False, + sparse=True, _weight=weights, mode='sum') @@ -3077,6 +3077,7 @@ def forward(self, indices1, offsets1, indices2, offsets2): FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \ .check_next("quantized::embedding_bag_byte_rowwise_offsets") \ .run(m.graph) + m(*dummy_inputs) diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index a1fbc308dfde..5780f3ffcbdf 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -747,7 +747,7 @@ def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig): module_out = qemb(indices) # Call the qembedding operator directly - ref = torch.ops.quantized.embedding_byte(w_packed, indices, sparse=False) + ref = torch.ops.quantized.embedding_byte(w_packed, indices, pruned_weights=False) self.assertEqual(module_out, ref) self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False, is_emb_bag=False) diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index 60dd789af367..99e06db982aa 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -3067,7 +3067,7 @@ def test_embedding_byte(self, num_embeddings, embedding_dim): low=0, high=num_embeddings, size=num_indices, dtype=np.int64)) packed_weight = prepack_op(qweight) - qresult = quant_op(packed_weight, indices, sparse=False) + qresult = quant_op(packed_weight, indices, pruned_weights=False) ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) torch.testing.assert_allclose(ref, qresult, atol=0.005, rtol=1e-3) diff --git a/test/test_autograd.py b/test/test_autograd.py index 3c0d0a9a2e8e..91da9e79b885 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6082,7 +6082,7 @@ def foo(a): class TestAutogradDeviceType(TestCase): def test_min_max_median_backprops_to_all_values(self, device): - for f in [torch.min, torch.max, torch.median]: + for f in [torch.min, torch.max, torch.median, torch.nanmedian]: x1 = torch.tensor([1., 0., 1., 0., 1., 0.], device=device, requires_grad=True) x2 = torch.tensor([float('nan'), float('nan'), float('nan')], requires_grad=True) for x in [x1, x2]: diff --git 
a/test/test_dataloader.py b/test/test_dataloader.py index 9074cc3c0b7d..67a9c8477e8b 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -19,7 +19,7 @@ from torch.utils.data.dataset import random_split from torch._utils import ExceptionWrapper from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, - IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, + IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, load_tests, TEST_WITH_ROCM, TEST_WITH_TSAN, IS_SANDCASTLE) try: @@ -1530,8 +1530,10 @@ def test_partial_workers(self): pin_memory_thread.join(JOIN_TIMEOUT) self.assertFalse(pin_memory_thread.is_alive()) + # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065 @skipIfRocm @unittest.skipIf(not HAS_PSUTIL, "psutil not found") + @slowTest def test_proper_exit(self): (r'''There might be ConnectionResetError or leaked semaphore warning ''' r'''(due to dirty process exit), but they are all safe to ignore''') diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index ac9f054d38c8..b0283f4f8107 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -3,7 +3,7 @@ import torch -from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, skipIfRocm, TEST_WITH_ROCM +from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed from test_jit import JitTestCase, RUN_CUDA @@ -561,7 +561,6 @@ def t(x: torch.Tensor, y: torch.Tensor): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm def test_binary_ops_permutation(self): # note that num_dim is exclusive from len(x), so we are not reducing # to single element (codegen limitation at this moment) @@ -603,7 +602,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm def test_reduction(self): for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]): # note that num_dim is exclusive from len(x), so we are not reducing @@ -617,7 +615,6 @@ def test_reduction(self): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm def test_reduction_permutation(self): x = [7, 8, 12] # note that num_dim is exclusive from len(x), so we are not reducing @@ -668,7 +665,6 @@ def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm def test_reduction_dtype(self): def t(x: torch.Tensor): o = torch.mul(x, 1.0) @@ -687,7 +683,6 @@ def t(x: torch.Tensor): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm def test_reduction_half(self): def t(x: torch.Tensor): o = torch.mul(x, 1.0) 
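The `slowTest` marker imported and applied in this patch (`test_numkeys_delkeys` in test_c10d.py, `test_proper_exit` in test_dataloader.py) moves long-running cases out of the default test pass; it is assumed to skip the test unless slow tests are explicitly enabled. A hedged sketch of how the decorator is applied — the class and test names here are hypothetical:

```python
# Hypothetical example (not from the patch) of applying the slowTest marker the
# way this change does for test_numkeys_delkeys and test_proper_exit.
from torch.testing._internal.common_utils import TestCase, run_tests, slowTest

class ExampleSlowTests(TestCase):   # hypothetical class name
    @slowTest                       # assumed: only runs when slow tests are explicitly enabled
    def test_takes_minutes(self):   # hypothetical test name
        self.assertTrue(True)

if __name__ == "__main__":
    run_tests()
```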
@@ -706,7 +701,6 @@ def t(x: torch.Tensor): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm def test_pw_single_reduction_partition(self): sizes = [8, 8, 8] dtype = torch.float @@ -731,7 +725,6 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm def test_single_reduction_broadcast(self): dtype = torch.float device = "cuda" @@ -802,5 +795,5 @@ def test_register_fuser(self): if __name__ == '__main__': - if not TEST_WITH_ROCM and GRAPH_EXECUTOR != ProfilingMode.PROFILING: + if GRAPH_EXECUTOR != ProfilingMode.PROFILING: run_tests() diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 5cb43cbe8079..257aa1d0b143 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1,5 +1,3 @@ -from collections import defaultdict - import operator import unittest import contextlib @@ -74,36 +72,6 @@ def tearDown(self): torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) - def assertAllFused(self, graph, except_for=()): - - # note this helper collects nodes on 'fast path' only - # i.e. the true blocks of specialized checks - def get_nodes_and_parents_recursively(block, kind, acc): - for node in block.nodes(): - if node.kind() == kind: - acc[block].append(node) - elif node.kind() == 'prim::DifferentiableGraph': - get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc) - elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or - node.inputs().__next__().node().kind() == 'prim::TypeCheck'): - get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc) - else: - for inner_block in node.blocks(): - get_nodes_and_parents_recursively(inner_block, kind, acc) - - allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate', - 'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck'} | set(except_for) - - fusion_groups = defaultdict(list) - get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups) - self.assertTrue(len(fusion_groups) == 1, 'got {}'.format(graph)) - (graph, fusion_nodes) = list(fusion_groups.items())[0] - # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes` - self.assertTrue(len(fusion_nodes) == 1, 'got {}'.format(graph)) - self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), - 'got {}'.format(graph)) - - def findFusionGroups(self, graph): result = [] for n in graph.nodes(): diff --git a/test/test_linalg.py b/test/test_linalg.py index 97c7b926faf4..a5478fa3cbfa 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -340,11 +340,13 @@ def test_autograd_and_jit(self, device, dtype): ('norm', (S, S, S), (), 'default_3d'), ('norm', (S,), (inf,), 'vector_inf'), ('norm', (S,), (3.5,), 'vector_3_5'), + ('norm', (S,), (0.5,), 'vector_0_5'), ('norm', (S,), (2,), 'vector_2'), ('norm', (S,), (1,), 'vector_1'), ('norm', (S,), (0,), 'vector_0'), ('norm', (S,), (-inf,), 'vector_neg_inf'), ('norm', (S,), (-3.5,), 'vector_neg_3_5'), + ('norm', (S,), (-0.5,), 'vector_neg_0_5'), ('norm', (S,), (2,), 'vector_neg_2'), ('norm', (S,), (1,), 'vector_neg_1'), ('norm', (S, S), (inf,), 'matrix_inf'), diff --git a/test/test_namedtensor.py 
b/test/test_namedtensor.py index 72e8bd6dd108..07c66775c948 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1220,6 +1220,7 @@ def kthvalue_wrapper(tensor, *args, **kwargs): Case(torch.mode, False, False, True, True, values_and_indices), Case(kthvalue_wrapper, False, False, True, True, values_and_indices), Case(torch.median, True, False, True, True, values_and_indices), + Case(torch.nanmedian, True, False, True, True, values_and_indices), ] for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()): diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 785ae4c4fb37..3b3616f64220 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -10,7 +10,7 @@ path = os.path.dirname(os.path.realpath(__file__)) aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml') all_operators_with_namedtuple_return = { - 'max', 'min', 'median', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', + 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', 'triangular_solve', 'cummax', 'cummin' } @@ -52,7 +52,7 @@ def test_namedtuple_return(self): op = namedtuple('op', ['operators', 'input', 'names', 'hasout']) operators = [ - op(operators=['max', 'min', 'median', 'mode', 'sort', 'topk', 'cummax', 'cummin'], input=(0,), + op(operators=['max', 'min', 'median', 'nanmedian', 'mode', 'sort', 'topk', 'cummax', 'cummin'], input=(0,), names=('values', 'indices'), hasout=True), op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), diff --git a/test/test_overrides.py b/test/test_overrides.py index b48d9056731f..15fd327c41a8 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -760,51 +760,54 @@ def test_wrapper(self): self.assertTrue(torch.allclose(torch.einsum('ik,jkl,il->ij', [a, b, c]), torch.nn.functional.bilinear(a, c, b))) -# TODO(@anjali411): re-enable this test -# class TestGradCheckOverride(TestCase): -# "Test that wrappers work with gradcheck." -# def test_gradcheck(self): -# from torch.autograd import gradcheck - -# a = wrap(torch.tensor(5.0, dtype=torch.double)) -# b = wrap(torch.tensor(6.0, dtype=torch.double)) - -# a.requires_grad = True -# b.requires_grad = True - -# gradcheck(torch.add, (a, b), raise_exception=False) - -# total_used_attrs = a.used_attrs.union(b.used_attrs) -# total_used_calls = a.used_calls.union(b.used_calls) - -# # These attributes (and the functions below) may change -# # if the gradcheck implementation changes. It's best to -# # aim for attributes that may be commonly present on other -# # Tensor-likes. -# self.assertEqual(total_used_attrs, { -# 'data', -# 'dtype', -# 'is_floating_point', -# 'is_sparse', -# 'layout', -# 'nelement', -# 'new_zeros', -# 'requires_grad', -# 'retain_grad', -# 'size', -# 'stride', -# }) - -# self.assertEqual(total_used_calls, { -# torch.Tensor.new_zeros, -# torch.Tensor.size, -# torch.Tensor.is_floating_point, -# torch.Tensor.nelement, -# torch.Tensor.retain_grad, -# torch.Tensor.stride, -# torch.autograd.grad, -# torch.add, -# }) +class TestGradCheckOverride(TestCase): + "Test that wrappers work with gradcheck." 
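The `TestGradCheckOverride` re-enabled here drives `gradcheck` (and now `gradgradcheck`) through a wrapped Tensor-like to record which attributes and calls the checker touches. A minimal sketch of the same calls on plain double tensors (illustration only; it omits the `wrap()` helper the test uses):

```python
# Minimal sketch (not from the patch) of the gradcheck / gradgradcheck calls the
# re-enabled test exercises, run on plain double tensors instead of wrap(...).
import torch
from torch.autograd import gradcheck, gradgradcheck

a = torch.tensor(5.0, dtype=torch.double, requires_grad=True)
b = torch.tensor(6.0, dtype=torch.double, requires_grad=True)

assert gradcheck(torch.add, (a, b))       # first-order numerical vs. analytical gradients
assert gradgradcheck(torch.add, (a, b))   # second-order check, newly added to the test
```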
+ def test_gradcheck(self): + from torch.autograd import gradcheck, gradgradcheck + + a = wrap(torch.tensor(5.0, dtype=torch.double)) + b = wrap(torch.tensor(6.0, dtype=torch.double)) + + a.requires_grad = True + b.requires_grad = True + + gradcheck(torch.add, (a, b), raise_exception=False) + gradgradcheck(torch.add, (a, b), raise_exception=False) + + total_used_attrs = a.used_attrs.union(b.used_attrs) + total_used_calls = a.used_calls.union(b.used_calls) + + # These attributes (and the functions below) may change + # if the gradcheck implementation changes. It's best to + # aim for attributes that may be commonly present on other + # Tensor-likes. + self.assertEqual(total_used_attrs, { + 'data', + 'device', + 'dtype', + 'is_complex', + 'is_floating_point', + 'is_sparse', + 'layout', + 'nelement', + 'new_zeros', + 'requires_grad', + 'retain_grad', + 'size', + 'stride', + }) + + self.assertEqual(total_used_calls, { + torch.Tensor.new_zeros, + torch.Tensor.size, + torch.Tensor.is_complex, + torch.Tensor.is_floating_point, + torch.Tensor.nelement, + torch.Tensor.retain_grad, + torch.Tensor.stride, + torch.autograd.grad, + torch.add, + }) if __name__ == '__main__': diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 739957569962..122c9f49ff59 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -9,8 +9,9 @@ from torch.testing._internal.te_utils import CudaCodeGenCreated, CudaCodeGenExecuted, \ LLVMCodeGenExecuted, SimpleIREvalExecuted +from torch.testing._internal.jit_utils import JitTestCase -class BaseTestClass(unittest.TestCase): +class BaseTestClass(JitTestCase): def setUp(self): self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) self.old_profiling_mode = torch._C._jit_set_profiling_mode(True) @@ -21,6 +22,8 @@ def setUp(self): torch._C._jit_override_can_fuse_on_gpu(True) self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() torch._C._jit_set_texpr_fuser_enabled(True) + self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(False) def tearDown(self): torch._C._jit_set_profiling_executor(self.old_profiling_executor) @@ -29,7 +32,15 @@ def tearDown(self): torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) + torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) + def assertLastGraphAllFused(self): + self.assertAllFused(torch.jit.last_executed_optimized_graph()) + +def warmup_and_run_forward(f, *args): + for _ in range(torch._C._jit_get_num_profiled_runs() + 1): + results = f(*args) + return results class TestTensorExprFuser(BaseTestClass): def test_easy(self): @@ -41,7 +52,8 @@ def easy(x, y): a = torch.rand(1024) b = torch.rand(1024) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) def test_three_arg(self): @@ -60,7 +72,8 @@ def easy(x, y, z): a = torch.rand(1024) b = torch.rand(1024) c = torch.rand(1024) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = a.numpy() + b.numpy() + c.numpy() np.testing.assert_allclose(npr, x.numpy()) assert ( @@ -90,7 +103,8 @@ def run_addcmul(x, y, z, w): ), ) - x = traced(rand_a, rand_b, rand_c, rand_d) + x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d) + self.assertLastGraphAllFused() y = run_addcmul(rand_a, 
rand_b, rand_c, rand_d) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6) @@ -120,6 +134,8 @@ def test(x, y, z): b = torch.rand(M, N, device="cuda") c = torch.rand(M, N, device="cuda") x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() np.testing.assert_allclose(npr, x.cpu().numpy()) assert cuda_cg_executed.elapsed_value() >= 1 @@ -152,10 +168,12 @@ def test(x, y, z): ), ) + a = torch.rand(*a_shape, device="cuda") b = torch.rand(*b_shape, device="cuda") c = torch.rand(*c_shape, device="cuda") - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() np.testing.assert_allclose(npr, x.cpu().numpy()) assert cuda_cg_executed.elapsed_value() >= 1 @@ -187,7 +205,8 @@ def np_easy(x, y, z): a = torch.rand(1024) b = torch.rand(1024) c = torch.rand(1024) - x = traced(a, b, c) + x = warmup_and_run_forward(a, b, c) + self.assertLastGraphAllFused() npr = np_easy(a.numpy(), b.numpy(), c.numpy()) np.testing.assert_allclose(npr, x.numpy()) @@ -214,7 +233,8 @@ def np_easy(x, y, z): a = torch.rand(shape) b = torch.rand(shape) c = torch.rand(shape) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = np_easy(a.numpy(), b.numpy(), c.numpy()) np.testing.assert_allclose(npr, x.numpy()) @@ -235,7 +255,8 @@ def np_easy(x, y, z): a = torch.rand(N, N) b = torch.rand(N) c = torch.rand(N, N) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = np_easy(a.numpy(), b.numpy(), c.numpy()) np.testing.assert_allclose(npr, x.numpy()) @@ -257,7 +278,9 @@ def foo_np(x, y, z): z = torch.rand(4) traced = torch.jit.trace(foo, (x, y, z)) - r = traced(x, y, z) + r = warmup_and_run_forward(x, y, z) + self.assertLastGraphAllFused() + rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) np.testing.assert_allclose(r, rnp) @@ -279,7 +302,8 @@ def foo_np(x, y, z): z = torch.rand(1024) traced = torch.jit.trace(foo, (x, y, z)) - r = traced(x, y, z) + r = warmup_and_run_forward(traced, x, y, z) + self.assertLastGraphAllFused() rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) np.testing.assert_allclose(r, rnp) @@ -304,7 +328,8 @@ def constant(x): traced = torch.jit.trace(constant, (torch.tensor([1.0]))) a = torch.tensor([1.0]) - x = traced(a) + x = warmup_and_run_forward(traced, a) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + 1.0, x.numpy()) def test_add_sub(self): @@ -320,7 +345,8 @@ def easy(x, y, z): a = torch.rand(1024) b = torch.rand(1024) c = torch.rand(1024) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy()) def test_promotion(self): @@ -335,7 +361,8 @@ def easy(x, y): a = torch.zeros(1024, dtype=torch.int32) b = torch.rand(1024, dtype=torch.float32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) def test_double(self): @@ -353,7 +380,8 @@ def easy(x, y): a = torch.rand(TENSOR_LEN, dtype=torch.double) b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_short(self): @@ -372,7 +400,8 
@@ def easy(x, y): a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_char(self): @@ -391,7 +420,8 @@ def easy(x, y): a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_int64_promotion(self): @@ -410,7 +440,8 @@ def easy(x, y): a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_eq(self): @@ -421,7 +452,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) a = torch.zeros(1024, dtype=torch.int32) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_ne(self): @@ -432,7 +464,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) a = torch.zeros(1024, dtype=torch.int32) b = torch.ones(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_ge(self): @@ -445,7 +478,8 @@ def easy(x, y): aa.fill(5) a = torch.from_numpy(aa) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_gt(self): @@ -456,7 +490,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) a = torch.ones(1024, dtype=torch.int32) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_le(self): @@ -469,7 +504,8 @@ def easy(x, y): aa.fill(5) a = torch.from_numpy(aa) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.zeros(1024), x.numpy()) def test_lt(self): @@ -482,7 +518,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))) a = torch.ones(1024, dtype=torch.int32, device=dev) b = torch.zeros(1024, dtype=torch.int32, device=dev) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy()) @suppress_warnings @@ -494,8 +531,9 @@ def test(x, y): a = 8.0 * torch.rand(1024) b = 8.0 * torch.rand(1024) np.testing.assert_allclose( - traced(a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0]) + warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0]) ) + self.assertLastGraphAllFused() def test_min_max_reduction(self): def test(x): @@ -503,7 +541,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(1024))) a = 8.0 * 
torch.rand(1024) - np.testing.assert_allclose(traced(a), np.amin(a.numpy()) + np.amax(a.numpy())) + np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) + self.assertLastGraphAllFused() def test_min_max_reduction2(self): def test(x): @@ -511,7 +550,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(1024))) a = 8.0 * torch.rand(1024) - np.testing.assert_allclose(traced(a), np.amin(a.numpy()) + np.amax(a.numpy())) + np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) + self.assertLastGraphAllFused() def test_min_max_reduction_dim1(self): def test(x): @@ -519,7 +559,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(16, 16))) a = 8.0 * torch.rand(16, 16) - np.testing.assert_allclose(traced(a), np.amin(a.numpy(), axis=1) + np.amax(a.numpy(), axis=1)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy(), axis=1) + np.amax(a.numpy(), axis=1)) + self.assertLastGraphAllFused() def test_min_max_reduction_dim1_2(self): def test(x): @@ -527,7 +568,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(16, 16))) a = 8.0 * torch.rand(16, 16) - np.testing.assert_allclose(traced(a)[0], np.amin(a.numpy(), axis=1)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin(a.numpy(), axis=1)) + self.assertLastGraphAllFused() def test_clamp(self): def test(x): @@ -539,7 +581,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) a = 20.0 * torch.rand(1024, device=dev) - 10.0 an = a.cpu().numpy() - np.testing.assert_allclose(traced(a).cpu(), np.clip(an + 3.0, 0.0, 6.0)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0)) + self.assertLastGraphAllFused() def test_relu(self): def test(x): @@ -550,7 +593,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) a = 20.0 * torch.rand(1024, device=dev) - 10.0 an = a.cpu().numpy() - np.testing.assert_allclose(traced(a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5)) + self.assertLastGraphAllFused() def test_reps(self): def easy(x, y): @@ -562,7 +606,7 @@ def easy(x, y): for _ in range(32): a = torch.ones(1024) b = torch.zeros(1024) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_add_const_rhs(self): @@ -571,7 +615,8 @@ def test(x): traced = torch.jit.trace(test, torch.rand(4)) x = torch.rand(4) - y = traced(x) + y = warmup_and_run_forward(traced, x) + self.assertLastGraphAllFused() np.testing.assert_allclose(x.numpy() + 3.0, y.numpy()) def test_int_output(self): @@ -582,7 +627,8 @@ def test(x, y, z): x, y, z = xs xn, yn, zn = [t.numpy() for t in xs] traced = torch.jit.trace(test, (x, y, z)) - res = traced(x, y, z) + res = warmup_and_run_forward(traced, x, y, z) + self.assertLastGraphAllFused() np.testing.assert_allclose(xn * yn * zn, res.numpy()) def test_binary_ops(self): @@ -670,7 +716,8 @@ def test_type_as(x, y): in1 = 20 * torch.rand(1024, device=dev) in2 = 20 * torch.rand(1024, device=dev) traced = torch.jit.trace(torch_fn, (in1, in2)) - x = traced(rand_a, rand_b) + x = warmup_and_run_forward(traced, rand_a, rand_b) + self.assertLastGraphAllFused() y = torch_fn(rand_a, rand_b) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) @@ -842,8 +889,10 @@ def test_threshold(x, y): } device_options = 
["cpu", "cuda"] if torch.cuda.is_available() else ['cpu'] + for torch_fn in fns: for dev in device_options: + # print(torch_fn, dev) rand_a = torch.rand(1024, device=dev) rand_b = torch.rand(1024, device=dev) ins = 20 * torch.rand(1024, device=dev) @@ -851,19 +900,22 @@ def test_threshold(x, y): cc.fill(np.nan) nans = torch.from_numpy(cc).to(dev) traced = torch.jit.trace(torch_fn, (ins, ins)) - x = traced(rand_a, rand_b) + x = warmup_and_run_forward(traced, rand_a, rand_b) + self.assertLastGraphAllFused() y = torch_fn(rand_a, rand_b) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) # nans - traced = torch.jit.trace(torch_fn, (ins, ins)) - x = traced(nans, rand_b) - y = torch_fn(nans, rand_b) - try: - np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) - except AssertionError: - # Print extra info before exiting: - print("Failed on dev=", dev, "function=", torch_fn) - np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + # TODO: reenable. Currently all of the tests fail + # traced = torch.jit.trace(torch_fn, (ins, ins)) + # x = warmup_and_run_forward(traced, rand_a, rand_b) + # y = torch_fn(nans, rand_b) + # try: + # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + # print("Succeeded on dev=", dev, "function=", torch_fn) + # except AssertionError: + # # Print extra info before exiting: + # print("Failed on dev=", dev, "function=", torch_fn) + # # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) def test_rand_like(self): devices = ["cuda"] if torch.cuda.is_available() else [] @@ -875,7 +927,8 @@ def run_rand_like(x, y): for device in devices: x = torch.rand(N, device=device) traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False) - x_v = traced(x, x) + x_v = warmup_and_run_forward(traced, x, x) + self.assertLastGraphAllFused() x_np = x.cpu().numpy() x1_mean = np.mean(x_np) x2_mean = np.mean(x_np ** 2) @@ -897,10 +950,12 @@ def test_min(x, y): x = torch.tensor([np.nan]) y = torch.tensor([1.0]) - assert np.isnan(tmin(x, y).item()) - assert np.isnan(tmin(y, x).item()) - assert np.isnan(tmax(x, y).item()) - assert np.isnan(tmax(y, x).item()) + assert np.isnan(warmup_and_run_forward(tmin, x, y).item()) + assert np.isnan(warmup_and_run_forward(tmin, y, x).item()) + self.assertLastGraphAllFused() + assert np.isnan(warmup_and_run_forward(tmax, x, y).item()) + assert np.isnan(warmup_and_run_forward(tmax, y, x).item()) + self.assertLastGraphAllFused() def test_remainder(self): def run_remainder(x, y): @@ -916,19 +971,22 @@ def run_remainder(x, y): # random floats traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() y = run_remainder(a, b) np.testing.assert_allclose(x.numpy(), y.numpy()) # div by 0 traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) - x = traced(zeros, a) + x = warmup_and_run_forward(traced, zeros, a) + self.assertLastGraphAllFused() y = run_remainder(zeros, a) np.testing.assert_allclose(x.numpy(), y.numpy()) # numerators and denominatos are nan traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) - x = traced(nans, a) + x = warmup_and_run_forward(traced, nans, a) + self.assertLastGraphAllFused() y = run_remainder(nans, a) np.testing.assert_allclose(x.numpy(), y.numpy()) @@ -941,7 +999,8 @@ def easy(x): traced = torch.jit.trace(easy, (torch.zeros(1024))) a = torch.zeros(1024) - b, c = traced(a) + b, c = warmup_and_run_forward(traced, 
a) + self.assertLastGraphAllFused() bp = a.numpy() + 1 cp = bp + bp np.testing.assert_allclose(b.numpy(), bp) @@ -956,7 +1015,8 @@ def easy(x): traced = torch.jit.trace(easy, (torch.zeros(1024, 1024))) a = torch.zeros(1024, 1024) - x = traced(a) + x = warmup_and_run_forward(traced, a) + self.assertLastGraphAllFused() npr = a.numpy() npr2 = npr + 1 npr_a, npr_b = np.array_split(npr2, 2) @@ -973,7 +1033,8 @@ def easy(*args): values = [torch.zeros(M, N, device=device) for N in Ns] traced = torch.jit.trace(easy, values) - x = traced(*values) + x = warmup_and_run_forward(traced, *values) + self.assertLastGraphAllFused() npr = [v.cpu().numpy() for v in values] npr_2 = [v + i for i, v in enumerate(npr)] npr_x = np.concatenate(npr_2, axis=1) @@ -1065,6 +1126,37 @@ def easy(x, y): # FIXME: interp.elapsed_value() also increments due to simplifier assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1 + def _test_softmax(self, device): + def test(x, y): + a = F.softmax(x, dim=0, dtype=torch.float32) + b = F.softmax(y, dim=0, dtype=torch.float32) + c = F.softmax(x, dim=1, dtype=torch.float32) + d = F.softmax(y, dim=1, dtype=torch.float32) + return a + b + c + d + + old = torch._C._jit_set_texpr_reductions_enabled(True) + traced = torch.jit.trace(test, (torch.randn(2, 3, device=device), torch.randn(2, 3, device=device))) + inp = torch.randn(2, 3, device=device) + res = traced(inp, inp) + # Use eager mode as reference. + ref = test(inp, inp) + np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06) + torch._C._jit_set_texpr_reductions_enabled(old) + + def test_softmax_cpu(self): + llvm = LLVMCodeGenExecuted() + interp = SimpleIREvalExecuted() + self._test_softmax('cpu') + # FIXME: interp.elapsed_value() also increments due to simplifier + assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1 + + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + @unittest.skip("global allocs are not supported yet.") + def test_softmax_cuda(self): + cuda = CudaCodeGenExecuted() + self._test_softmax('cuda') + assert cuda.elapsed_value() == 1 + def test_transpose(self): @torch.jit.script def test(x, y, z): @@ -1178,7 +1270,8 @@ def run_rshift(x, y): b = torch.zeros(128, dtype=torch.int32, device=device) inp = torch.ones(128, dtype=torch.int32, device=device) traced = torch.jit.trace(fn, (inp, inp)) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() y = fn(a, b) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) @@ -1189,7 +1282,8 @@ def run_where(x, y): a = torch.rand(1024, dtype=float) b = torch.rand(1024, dtype=float) traced = torch.jit.trace(run_where, (torch.zeros(1024), torch.zeros(1024))) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() y = run_where(a, b) np.testing.assert_allclose(x.numpy(), y.numpy()) @@ -1200,9 +1294,10 @@ def test(x): return (x + y) - (y - x) a = torch.rand(4, device="cuda") scripted = torch.jit.script(test) - scripted(a) + out = warmup_and_run_forward(scripted, a) + self.assertLastGraphAllFused() cx = CudaCodeGenExecuted() - assert torch.allclose(scripted(a), 2 * a) + assert torch.allclose(out, 2 * a) assert cx.elapsed_value() == 1 def test_mask(self): @@ -1214,8 +1309,9 @@ def test(x): for d in devices: x = torch.rand(4, device=d) > 0.5 scripted = torch.jit.script(test) - scripted(x) - assert torch.equal(scripted(x), test(x)) + out = warmup_and_run_forward(scripted, x) + self.assertLastGraphAllFused() + assert torch.equal(out, test(x)) def 
test_simple_add(self): val = torch._C._jit_get_te_generate_block_code() diff --git a/test/test_torch.py b/test/test_torch.py index ef40a54d4eee..bd159a25b0bd 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1102,42 +1102,7 @@ def compare(t, k, dim, dir): def test_topk_arguments(self): q = torch.randn(10, 2, 10) # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1) - self.assertRaises(TypeError, lambda: q.topk(4, True)) - - def test_median(self): - for size in (155, 156): - x = torch.rand(size, size) - x0 = x.clone() - - nelem = x.nelement() - res1val = torch.median(x) - res2val, _ = torch.sort(x.view(nelem)) - ind = int(math.floor((nelem + 1) / 2) - 1) - - self.assertEqual(res2val[ind], res1val, atol=0, rtol=0) - - res1val, res1ind = torch.median(x, dim=1, keepdim=False) - res2val, res2ind = torch.sort(x) - ind = int(math.floor((size + 1) / 2) - 1) - - self.assertEqual(res2val.select(1, ind), res1val, atol=0, rtol=0) - self.assertEqual(res2val.select(1, ind), res1val, atol=0, rtol=0) - - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.median(x, dim=-1, keepdim=False, out=(res2val, res2ind)) - self.assertEqual(res2val, res1val, atol=0, rtol=0) - self.assertEqual(res2ind, res1ind, atol=0, rtol=0) - - # Test non-default dim - res1val, res1ind = torch.median(x, 0, keepdim=False) - res2val, res2ind = torch.sort(x, 0) - self.assertEqual(res1val, res2val[ind], atol=0, rtol=0) - self.assertEqual(res1ind, res2ind[ind], atol=0, rtol=0) - - # input unchanged - self.assertEqual(x, x0, atol=0, rtol=0) + self.assertRaises(TypeError, lambda: q.topk(4, True)) def test_mode(self): x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE) @@ -4660,6 +4625,7 @@ def add_neg_dim_tests(): ('cummin', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('nanmedian', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]), ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), @@ -9406,7 +9372,7 @@ def test_dim_reduction(self, device): self.assertEqual(x[:, :2].argmax().item(), 2) dim_red_fns = [ - "mean", "median", "mode", "norm", "prod", + "mean", "median", "nanmedian", "mode", "norm", "prod", "std", "sum", "var", "max", "min", "amax", "amin"] def normfn_attr(t, dim, keepdim=False, out=None): @@ -9446,19 +9412,8 @@ def test_multidim(x, dim): x = torch.randn(dims, device=device) test_multidim(x, singleton_dim) - # check reducing median with NaNs - # If the element in the median is a NaN, there can be issues - # when comparining with other nan elements - if fn_name == 'median': - y = torch.full((1, 3), np.nan, dtype=torch.float64, device=device) - y[:, :1] = 1.1 - values, indices = fn_tuple(y, dim=1) - expected_values = torch.tensor([nan], dtype=torch.float64, device=device) - self.assertEqual(values, expected_values) - self.assertTrue(torch.isnan(y.flatten()[indices[0]])) - # check reducing with output kwargs - if fn_name in ['median', 'mode', 'max', 'min']: + if fn_name in ['median', 'nanmedian', 'mode', 'max', 'min']: y = torch.randn(5, 3, device=device) values = torch.randn(5, 3, device=device) indices = torch.zeros(5, 3, device=device).long() - 1 @@ -10507,7 +10462,8 @@ def check_norm(a, b, expected_norm, gels_result): @skipCUDAIfNoMagma @skipCPUIfNoLapack 
@tf32_on_and_off(0.001) - def test_qr(self, device): + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_qr(self, device, dtype): def run_test(tensor_dims, some): A = torch.randn(*tensor_dims, device=device) Q, R = torch.qr(A, some=some) @@ -10542,6 +10498,106 @@ def run_test(tensor_dims, some): for tensor_dims, some in product(tensor_dims_list, [True, False]): run_test(tensor_dims, some) + @dtypes(torch.int, torch.long, torch.float, torch.double) + @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double) + def test_median_real_values(self, device, dtype): + # Generate random 0-3D sizes + sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] + for size in sizes: + # Create random input tensor + t = torch.randn(size, device=device).type(dtype) + t_numpy = t.cpu().numpy() + res = t.median() + self.assertEqual(res, t.nanmedian()) + k = int((t.numel() - 1) / 2) + self.assertEqual(res, t.view(-1).sort()[0][k]) + if t.numel() % 2 == 1: + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + self.assertEqual(res.cpu().numpy(), np.median(t_numpy)) + for dim in range(t.ndim): + res = t.median(dim, True) + self.assertEqual(res, t.nanmedian(dim, True)) + size = t.size(dim) if t.ndim > 0 else 1 + k = int((size - 1) / 2) + self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim)) + self.assertEqual(res[0], t.gather(dim, res[1])) + if size % 2 == 1: + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True)) + + @dtypes(torch.float, torch.double) + @dtypesIfCUDA(torch.half, torch.float, torch.double) + def test_median_nan_values(self, device, dtype): + # Generate random 0-3D sizes + sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] + for size in sizes: + # Create random input tensor with nan values + t = torch.rand(size, device=device, dtype=dtype) + t.masked_fill_(t < 0.1, float('nan')) + t_numpy = t.cpu().numpy() + for op in [torch.median, torch.nanmedian]: + numpy_op = np.median if op == torch.median else np.nanmedian + res = op(t) + num_nan = t.isnan().sum() + if op == torch.median and num_nan > 0: + k = t.numel() - 1 + else: + k = int((t.numel() - num_nan - 1) / 2) + self.assertEqual(res, t.view(-1).sort()[0][k]) + if (t.numel() - num_nan) % 2 == 1: + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + self.assertEqual(res.item(), numpy_op(t.cpu().numpy())) + for dim in range(t.ndim): + res = op(t, dim, True) + size = t.size(dim) if t.ndim > 0 else 1 + num_nan = t.isnan().sum(dim, True) + if op == torch.median: + k = torch.where(num_nan > 0, size - 1, int((size - 1) / 2)) + else: + k = ((size - num_nan - 1) / 2).type(torch.long) + self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k)) + self.assertEqual(res[0], t.gather(dim, res[1])) + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + mask = (size - num_nan) % 2 == 1 + res = res[0].masked_select(mask).cpu() + ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()] + self.assertEqual(res, torch.from_numpy(ref)) + + def test_median_corner_cases(self, device): + def check(op, a, args, key): + t = torch.tensor(a, device=device) 
+ res = op(t, *args) + if not args: + key = torch.tensor(key, device=device) + else: + if len(key) == 1: + key = torch.tensor(key[0], device=device) + res = res[0] + else: + key = (torch.tensor(key[0], device=device), torch.tensor(key[1], device=device)) + self.assertEqual(res, key) + + nan = float('nan') + check(torch.median, nan, [], nan) + check(torch.nanmedian, nan, [], nan) + check(torch.median, nan, [0], [nan, 0]) + check(torch.nanmedian, nan, [0], [nan, 0]) + check(torch.median, [nan], [0, True], [[nan], [0]]) + check(torch.nanmedian, [nan], [0, True], [[nan], [0]]) + check(torch.median, [nan], [0, True], [[nan], [0]]) + check(torch.nanmedian, [nan], [0, True], [[nan], [0]]) + + # Indices are not deterministic here so can only check values + check(torch.median, [[nan, nan], [1, 2]], [0], [[nan, nan]]) + check(torch.nanmedian, [[nan, nan], [1, 2]], [0], [[1, 2.]]) + check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]]) + check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]]) + + @onlyOnCPUAndCUDA @dtypes(torch.float, torch.double) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") @@ -13414,14 +13470,12 @@ def test_reduction_empty(self, device): # name, function, identity ('max', torch.max, None), ('amax', torch.amax, None), - ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None), ('argmax', torch.argmax, None), ('min', torch.min, None), ('amin', torch.amin, None), ('argmin', torch.argmin, None), ('mode', torch.mode, None), - ('median', torch.median, None), - + ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None), ('prod', torch.prod, 1.), ('sum', torch.sum, 0.), ('norm', torch.norm, 0.), @@ -13438,12 +13492,27 @@ def test_reduction_empty(self, device): ident_err = 'operation does not have an identity' self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x)) + # median and nanmedian have been updated to follow the new convention for empty tensors + # where it should only fail if the dimension being reduced has size 0. + for name, fn in [('median', torch.median), ('nanmedian', torch.nanmedian)]: + ident_err = 'does not have an identity' + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True)) + self.assertEqual(fn(x, dim=0)[0].shape, (shape[1], shape[2])) + self.assertEqual(fn(x, dim=0, keepdim=True)[0].shape, (1, shape[1], shape[2])) + self.assertEqual(fn(x, dim=2)[0].shape, (shape[0], shape[1])) + self.assertEqual(fn(x, dim=2, keepdim=True)[0].shape, (shape[0], shape[1], 1)) + for item in fns_to_test: name, fn, identity = item if identity is None: ident_err = 'does not have an identity' + + # Reductions over non-zero dimensions should work even for empty tensors + # See https://github.com/pytorch/pytorch/issues/34907 for a discussion on this. 
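The new median tests repeatedly note that comparison against numpy is only possible for odd-length reductions, because the two libraries disagree on the even-length case: torch.median returns the lower of the two middle values, while np.median averages them (torch.quantile with q=0.5 matches numpy's convention, as the updated docs below point out). A minimal REPL illustration of that convention, using only operators that already exist:

>>> import torch, numpy as np
>>> t = torch.tensor([1., 2., 3., 4.])
>>> torch.median(t)               # lower of the two middle values
tensor(2.)
>>> float(np.median(t.numpy()))   # numpy averages the two middle values
2.5
>>> torch.quantile(t, 0.5)        # matches numpy's convention
tensor(2.5000)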
self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2)) self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2, keepdim=True)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1)) self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True)) else: @@ -16903,9 +16972,9 @@ def test_addmm_sizes(self, device, dtype): @onlyCUDA def test_matmul_45724(self, device): # https://github.com/pytorch/pytorch/issues/45724 - a = torch.rand(65537, 22, 64).cuda().half() - b = torch.rand(65537, 64, 22).cuda().half() - c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device='cuda') + a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) + b = torch.rand(65537, 64, 22, device=device, dtype=torch.half) + c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device) cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half() torch.matmul(a, b, out=c) self.assertEqual(c, cpu_result) diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py index 56c44b904b47..1e9bb617425a 100644 --- a/test/test_xnnpack_integration.py +++ b/test/test_xnnpack_integration.py @@ -6,7 +6,7 @@ from torch.utils.mobile_optimizer import optimize_for_mobile from torch.testing import FileCheck import torch.testing._internal.hypothesis_utils as hu -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, slowTest from hypothesis import given, assume from hypothesis import strategies as st import io @@ -1022,6 +1022,8 @@ def forward(self, x): pattern_count_optimized_map, data_shape) + # See https://github.com/pytorch/pytorch/issues/46066 + @slowTest def test_conv1d_with_relu_fc(self): batch_size_list = range(1, 3) input_channels_per_group_list = range(10, 12) diff --git a/third_party/fbgemm b/third_party/fbgemm index fe9164007c33..974d2b41e7ce 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit fe9164007c3392a12ea51a19b0f4e9f40d24f88d +Subproject commit 974d2b41e7ce367545ee244794f857e82ba2a590 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 2af8ee81604e..9c6cd4c578de 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -695,6 +695,9 @@ - name: median(Tensor self) -> Tensor self: evenly_distribute_backward(grad, self, result) +- name: nanmedian(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + # This is in theory incorrect in the following case: # sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value # | at middle position of the @@ -712,6 +715,9 @@ - name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) +- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) + - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2ad2f647c2af..e428b1a89cae 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -273,6 +273,15 @@ class Graph: class Value: ... +# Defined in torch/csrc/jit/ir/ir.h +class Block: + ... 
+ +# Defined in torch/csrc/jit/ir/ir.h +class Node: + ... + + # Defined in torch/aten/src/ATen/core/function_schema.h class FunctionSchema: ... diff --git a/torch/_C/_functions.pyi b/torch/_C/_functions.pyi index 4ad76e4f86e5..bbc6606722f3 100644 --- a/torch/_C/_functions.pyi +++ b/torch/_C/_functions.pyi @@ -3,7 +3,7 @@ from typing import AnyStr, List class UndefinedGrad: def __init__(self) -> None: ... - def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ... + def __call__(self, *inputs: Tensor) -> List[Tensor]: ... ... class DelayedError: diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index fc27da90d3d8..33d2249af284 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2202,6 +2202,13 @@ def callable(a, b) -> number See :func:`torch.median` """) +add_docstr_all('nanmedian', + r""" +nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.nanmedian` +""") + add_docstr_all('min', r""" min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 31580e4e0472..1998672d6b34 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3732,9 +3732,9 @@ def merge_dicts(*dicts): .. math:: (\text{start}, - \text{start} + \frac{\text{end} - \text{start}}{\text{steps}}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, \ldots, - \text{start} + (\text{steps} - 1) * \frac{\text{end} - \text{start}}{\text{steps}}, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, \text{end}) """ + """ @@ -4057,9 +4057,9 @@ def merge_dicts(*dicts): .. math:: (\text{base}^{\text{start}}, - \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps}})}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \ldots, - \text{base}^{(\text{start} + (\text{steps} - 1) * \frac{\text{end} - \text{start}}{ \text{steps}})}, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \text{base}^{\text{end}}) """ + """ @@ -4629,13 +4629,12 @@ def merge_dicts(*dicts): r""" median(input) -> Tensor -Returns the median value of all elements in the :attr:`input` tensor. +Returns the median of the values in :attr:`input`. .. note:: The median is not unique for :attr:`input` tensors with an even number of elements. In this case the lower of the two medians is returned. To - compute the mean of both medians in :attr:`input`, use :func:`torch.quantile` - with ``q=0.5`` instead. + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. .. warning:: This function produces deterministic (sub)gradients unlike ``median(dim=0)`` @@ -4653,9 +4652,8 @@ def merge_dicts(*dicts): .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) -Returns a namedtuple ``(values, indices)`` where ``values`` is the median -value of each row of the :attr:`input` tensor in the given dimension -:attr:`dim`. And ``indices`` is the index location of each median value found. +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. @@ -4677,14 +4675,15 @@ def merge_dicts(*dicts): Do not expect the same result when run on CPU and GPU in general. 
For the same reason do not expect the gradients to be deterministic. - Args: {input} {dim} {keepdim} Keyword args: - out (tuple, optional): the result tuple of two output tensors (max, max_indices) + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. Example:: @@ -4698,6 +4697,60 @@ def merge_dicts(*dicts): torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) """.format(**single_dim_common)) +add_docstr(torch.nanmedian, + r""" +nanmedian(input) -> Tensor + +Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. +When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, +while this function will return the median of the non-``NaN`` elements in :attr:`input`. +If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + +Args: + {input} + +Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + +.. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values +found in the dimension :attr:`dim`. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has +one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the +median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + +Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) +""".format(**single_dim_common)) + add_docstr(torch.quantile, r""" quantile(input, q) -> Tensor diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 531bcc6f27d8..9b1ad2675fe5 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -277,7 +277,7 @@ def gradcheck( identical inputs through the differentiation, the results must either match exactly (default, 0.0) or be within this tolerance. check_undefined_grad (bool, options): if True, check if undefined output grads - are supported and treated as zeros + are supported and treated as zeros, for ``Tensor`` outputs. 
Returns: True if all differences satisfy allclose condition @@ -462,7 +462,11 @@ def check_undefined_grad_support(output_to_check): return True # All backward functions must work properly if all output grads are undefined - outputs_to_check = [[torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*tupled_inputs))]] + outputs_to_check = [[ + torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*tupled_inputs)) + # This check filters out Tensor-likes that aren't instances of Tensor. + if isinstance(o, torch.Tensor) + ]] # If there are multiple output grads, we should be able to undef one at a time without error if len(outputs_to_check[0]) > 1: diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 8b05caea5aba..5efb77ea536a 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -660,6 +660,54 @@ void all2all(at::Tensor& input, #endif } +void send( + const at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int dst) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 7) + using namespace torch::cuda::nccl::detail; + NCCL_CHECK(ncclSend( + input.data_ptr(), + input.numel(), + to_nccl_data_type(input), + dst, + to_nccl_comm(comm), + stream.stream())); +#else + AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + +void recv( + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int src) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 7) + using namespace torch::cuda::nccl::detail; + NCCL_CHECK(ncclRecv( + output.data_ptr(), + output.numel(), + to_nccl_data_type(output), + src, + to_nccl_comm(comm), + stream.stream())); +#else + AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index ecf854ec2009..4cbae2e0208a 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -143,6 +143,17 @@ TORCH_CUDA_API void all2all( ncclComm_t comm, at::cuda::CUDAStream& stream); +TORCH_CUDA_API void send( + const at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int dst); + +TORCH_CUDA_API void recv( + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int src); } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 90128e48ee1d..53541960f300 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -946,31 +946,6 @@ void Reducer::prepare_for_backward( std::unordered_set seen; std::vector queue; - // Check that any prior reduction has finished. - // The variable `require_finalize_` is true until all gradients - // have been computed and reduction of all buckets has been kicked off. - if (require_finalize_) { - TORCH_CHECK( - false, - "Expected to have finished reduction in the prior iteration before ", - "starting a new one. ", - "", - "This error indicates that your module has parameters that were ", - "not used in producing loss. 
", - "", - "You can enable unused parameter detection by (1) passing the keyword " - "argument `find_unused_parameters=True` to ", - "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ", - "`forward` function outputs participate in calculating loss. " - "", - "If you already have done the above two steps, then the distributed ", - "data parallel module wasn't able to locate the output tensors in the ", - "return value of your module's `forward` function. ", - "Please include the loss function and the structure of the return ", - "value of `forward` of your module when reporting this issue (e.g. ", - "list, dict, iterable)."); - } - // Reset accounting. expect_autograd_hooks_ = true; next_bucket_ = 0; @@ -1325,6 +1300,11 @@ void Reducer::sync_bucket_indices( } bool Reducer::rebuild_buckets() { + // Ensure reduction for previous backwards pass is finished. If user's model + // has unused parameters for example, this will raise an error recommending to + // run with find_unused_parameters=True, instead of the size mismatch + // exception below. + ensure_prior_reduction_finished(); std::lock_guard lock(mutex_); if (!should_rebuild_buckets() || rebuilt_params_.empty()) { return false; @@ -1381,6 +1361,34 @@ void Reducer::register_comm_hook(std::unique_ptr iface) { comm_hook_ = std::move(iface); } +void Reducer::ensure_prior_reduction_finished() { + // Check that any prior reduction has finished. + // The variable `require_finalize_` is true until all gradients + // have been computed and reduction of all buckets has been kicked off. + if (require_finalize_) { + TORCH_CHECK( + false, + "Expected to have finished reduction in the prior iteration before ", + "starting a new one. ", + "", + "This error indicates that your module has parameters that were ", + "not used in producing loss. ", + "", + "You can enable unused parameter detection by (1) passing the keyword " + "argument `find_unused_parameters=True` to ", + "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ", + "`forward` function outputs participate in calculating loss. " + "", + "If you already have done the above two steps, then the distributed ", + "data parallel module wasn't able to locate the output tensors in the ", + "return value of your module's `forward` function. ", + "Please include the loss function and the structure of the return ", + "value of `forward` of your module when reporting this issue (e.g. ", + "list, dict, iterable)."); + } + +} + namespace { // Tensors may be coalesced into buckets. Buckets must contain tensors of diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h index 5a17dbe6f1c2..efb2060a5533 100644 --- a/torch/csrc/distributed/c10d/reducer.h +++ b/torch/csrc/distributed/c10d/reducer.h @@ -170,6 +170,10 @@ class Reducer { void finalize_backward(); + // Asserts that the reduction for the previous iteration has finished before + // rebuilding buckets or kicking off the next one. 
+ void ensure_prior_reduction_finished(); + // Broadcast rebuilt buckets from rank 0 to other ranks before initializing // the buckets void sync_bucket_indices(std::vector>& bucket_indices); diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index c429fde123c6..1ca3980a5c53 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -503,11 +503,11 @@ void RequestCallbackImpl::processRpcWithErrors( } bool RequestCallbackImpl::cudaAvailable() const { - #ifdef USE_CUDA +#ifdef USE_CUDA return true; - #else +#else return false; - #endif +#endif } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h index 836e496fb069..ed3999cd4a80 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.h +++ b/torch/csrc/distributed/rpc/request_callback_impl.h @@ -55,7 +55,7 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { const int64_t messageId, const std::shared_ptr& responseFuture) const override; - bool cudaAvailable() const override; + bool cudaAvailable() const override; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index d41c8f271104..ac9fbaae21f9 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -591,11 +591,11 @@ Message RequestCallbackNoPython::handleError( } bool RequestCallbackNoPython::cudaAvailable() const { - #ifdef USE_CUDA +#ifdef USE_CUDA return true; - #else +#else return false; - #endif +#endif } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 11c5408c2c35..6bf65f4c2628 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -29,44 +29,6 @@ const std::string kClientActiveCalls = "agent.client_active_calls"; const std::string kServerActiveCalls = "agent.server_active_calls"; const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls"; -inline void checkCPUTensor(const torch::Tensor& tensor) { - TORCH_CHECK( - tensor.device() == at::kCPU, - "TensorPipeAgent only supports CPU tensors by default. Sending " - "GPU tensors using RPC requires explicitly configurations using " - "`set_device_map` on `TensorPipeRpcBackendOptions`. 
Got a tensor " - "with device ", - tensor.device(), - ", but no device map is specified."); -} - -std::vector getDevicesForTensors( - const std::string& remoteName, - const std::vector& tensors, - const std::unordered_map& deviceMaps) { - const auto workerIter = deviceMaps.find(remoteName); - if (workerIter == deviceMaps.end()) { - for (const auto& tensor : tensors) { - checkCPUTensor(tensor); - } - return {}; - } else { - std::vector deviceIndices; - deviceIndices.reserve(tensors.size()); - const auto& deviceMap = workerIter->second; - for (const auto& tensor : tensors) { - const auto deviceIter = deviceMap.find(tensor.device().index()); - if (deviceIter == deviceMap.end()) { - checkCPUTensor(tensor); - deviceIndices.push_back(-1); - } else { - deviceIndices.push_back(deviceIter->second); - } - } - return deviceIndices; - } -} - } // namespace C10_DEFINE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration); @@ -403,7 +365,7 @@ void TensorPipeAgent::onListenerAccepted( void TensorPipeAgent::pipeRead( const std::shared_ptr& pipe, - std::function fn) { + std::function fn) noexcept { pipe->readDescriptor([fn{std::move(fn)}, pipe]( const tensorpipe::Error& error, tensorpipe::Message tpMessage) mutable { @@ -439,14 +401,11 @@ void TensorPipeAgent::pipeRead( void TensorPipeAgent::pipeWrite( const std::shared_ptr& pipe, Message&& rpcMessage, - std::function fn) { + std::vector&& devices, + std::function fn) noexcept { tensorpipe::Message tpMessage; TensorpipeWriteBuffers tpBuffers; - const auto& deviceMaps = - rpcMessage.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_; - auto devices = getDevicesForTensors( - pipe->getRemoteName(), rpcMessage.tensors(), deviceMaps); std::tie(tpMessage, tpBuffers) = tensorpipeSerialize(std::move(rpcMessage), std::move(devices)); @@ -480,47 +439,18 @@ void TensorPipeAgent::sendCompletedResponseMessage( Message&& responseMessage = std::move(*futureResponseMessage).moveValue(); responseMessage.setId(messageId); if (!error) { - const auto& iter = reverseDeviceMaps_.find(pipe->getRemoteName()); - if (iter == opts_.deviceMaps.end()) { - for (const auto& t : responseMessage.tensors()) { - if (!t.device().is_cpu()) { - responseMessage = createExceptionResponse( - c10::str( - "TensorPipe RPC backend only supports CPU tensors by default," - " please move your tensors to CPU before sending them over " - "RPC, or call `set_device_map` on " - "`TensorPipeRpcBackendOptions` to explicitly configure " - "device mapping. Response device mapping is not available for " - "destination ", - pipe->getRemoteName(), - ", but found tensor on device: ", - t.device()), - responseMessage.id()); - break; - } - } - } else { - const auto& deviceMap = iter->second; - for (const auto& t : responseMessage.tensors()) { - if (!t.device().is_cpu() && - deviceMap.find(t.device().index()) == deviceMap.end()) { - responseMessage = createExceptionResponse( - c10::str( - "TensorPipe RPC backend only supports CPU tensors by default." 
- " Response device mapping is not available for destination ", - pipe->getRemoteName(), - " for device ", - t.device(), - " but received a tensor on that device."), - responseMessage.id()); - break; - } - } + std::vector devices; + + try { + devices = getDevicesForTensors(pipe->getRemoteName(), responseMessage); + } catch (const std::exception& e) { + responseMessage = createExceptionResponse(e.what(), responseMessage.id()); } pipeWrite( pipe, std::move(responseMessage), + std::move(devices), [this, pipe, messageId](const tensorpipe::Error& error) { if (error) { LOG(WARNING) @@ -539,6 +469,7 @@ void TensorPipeAgent::sendCompletedResponseMessage( pipeWrite( pipe, createExceptionResponse(error->what(), responseMessage.id()), + {}, [this, pipe, messageId](const tensorpipe::Error& error) { if (error) { LOG(WARNING) @@ -670,6 +601,11 @@ std::shared_ptr TensorPipeAgent::send( lock.unlock(); + // Get devices for tensors in the request message. This can throw if device + // maps are not configured properly for this request. + auto devices = + getDevicesForTensors(clientPipe.pipe_->getRemoteName(), requestMessage); + futureResponseMessage->futMsg.addCallback([this]() { TORCH_INTERNAL_ASSERT( this->threadPool_.inThreadPool(), @@ -706,6 +642,7 @@ std::shared_ptr TensorPipeAgent::send( pipeWrite( clientPipe.pipe_, std::move(requestMessage), + std::move(devices), [this, &clientPipe, messageId](const tensorpipe::Error& error) mutable { if (error) { if (error.isOfType() && @@ -1083,6 +1020,53 @@ void TensorPipeAgent::markFutureWithError( } } +std::vector TensorPipeAgent::getDevicesForTensors( + const std::string& remoteName, + const Message& message) const { + const auto& deviceMaps = + message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_; + + const auto errStr = c10::str( + "TensorPipe RPC backend only supports CPU tensors by default, please " + "move your tensors to CPU before sending them over RPC, or call " + "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly " + "configure device mapping. ", + message.isRequest() ? "Request" : "Response", + " device mapping is not available for destination ", + remoteName); + + const auto& iter = deviceMaps.find(remoteName); + if (iter == deviceMaps.end()) { + for (const auto& t : message.tensors()) { + TORCH_CHECK( + t.device().is_cpu(), + errStr, + ", but found tensor on device: ", + t.device()); + } + return {}; + } else { + std::vector deviceIndices; + deviceIndices.reserve(message.tensors().size()); + const auto& deviceMap = iter->second; + for (const auto& t : message.tensors()) { + if (t.device().is_cpu()) { + deviceIndices.push_back(-1); + } else { + const auto deviceIter = deviceMap.find(t.device().index()); + TORCH_CHECK( + deviceIter != deviceMap.end(), + errStr, + " for device ", + t.device(), + " but received a tensor on that device."); + deviceIndices.push_back(deviceIter->second); + } + } + return deviceIndices; + } +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index a3df04011285..b4a500de65be 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -200,14 +200,15 @@ class TensorPipeAgent : public RpcAgent { // by client, and read request messages by server. 
void pipeRead( const std::shared_ptr&, - std::function); + std::function) noexcept; // TensorPipe write function that could be used to write response // messages by server, and write request messages by client. void pipeWrite( const std::shared_ptr&, Message&& message, - std::function); + std::vector&& devices, + std::function) noexcept; // Callback of listener accept() void onListenerAccepted( @@ -233,6 +234,10 @@ class TensorPipeAgent : public RpcAgent { uint64_t requestSize, const std::string& destWorkerName); + inline std::vector getDevicesForTensors( + const std::string& remoteName, + const Message& message) const; + // When a request+response completes, we need to mark the future message as // complete. However, if its timeout has already expired, it already has an // error set. There is no atomic "test-and-set" way to mark a future complete diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 54da8d8c52c1..d3b08cac4d5c 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -88,8 +88,7 @@ std::tuple tensorpipeSerialize( buffers.pickle.data(), buffers.pickle.size()}); const auto& tensorDataVec = pickler.tensorData(); for (size_t i = 0; i < tensorDataVec.size(); ++i) { - const auto& tensorData = - jit::getWriteableTensorData(tensorDataVec[i]); + const auto& tensorData = jit::getWriteableTensorData(tensorDataVec[i]); // Enforce memory copy if tensor is created from torch::from_blob, means // that the tensor doesn't own the memory. std::string metadata = diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index fa97ea116a0c..41d503c09160 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -527,7 +527,8 @@ void populateRemoteProfiledEvents( if (!foundCpuStart && 0 == strcmp(e.name(), "__start_profile")) { profilerStart = &e; foundCpuStart = true; - } else if (cudaProfilingEnabled && 0 == strcmp(e.name(), "__cuda_start_event")) { + } else if ( + cudaProfilingEnabled && 0 == strcmp(e.name(), "__cuda_start_event")) { e.setCudaUs(e.cpuUs()); auto device = e.device(); TORCH_CHECK( diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index f6e791f0edba..f976af9ac0a3 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -113,7 +113,11 @@ class CudaKernelGenerator : private OptInConstDispatch { // Shared memory if (has_dynamic_smem || has_reductions) { indent() << "alignas(" +#ifndef __HIP_PLATFORM_HCC__ << dataTypeSize(kernel_summary.largest_smem_data_type) +#else + << 8 // for HIP, we want 8-aligned even for smaller datatypes +#endif << ") extern __shared__ char array[];\n"; if (has_dynamic_smem) { diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index a0df3c784778..250038149b0c 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -25,9 +25,13 @@ int FusionExecutor::fusion_id_counter_ = 0; std::string FusionExecutor::getStructuredCode(const std::string& kernel) { // generating cuda code; - std::string code = std::string("namespace ") + - FusionExecutor::kernelNamespace() + " {\n" + - executor_utils::kernelPreamble() + kernel + "}\n"; + std::string code = ""; +#ifdef __HIP_PLATFORM_HCC__ + code += std::string("#include \n") + + std::string("#include \n"); +#endif + code += std::string("namespace ") + 
FusionExecutor::kernelNamespace() + + " {\n" + executor_utils::kernelPreamble() + kernel + "}\n"; const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG"); if (debug_env && atoi(debug_env)) { diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index af4e127cc548..19f873c90b0a 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -272,10 +272,14 @@ NvrtcFunction nvrtcCompile( at::globalContext().getNVRTC().nvrtcDestroyProgram(&program)); }); +#ifdef __HIP_PLATFORM_HCC__ + std::vector args = {"--std=c++14"}; +#else const std::string compute = "--gpu-architecture=compute_" + std::to_string(major) + std::to_string(minor); std::vector args = { "--std=c++14", compute.c_str(), "-default-device"}; +#endif const char* disable_fma = getenv("PYTORCH_CUDA_FUSER_DISABLE_FMA"); // int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0; @@ -346,6 +350,7 @@ NvrtcFunction nvrtcCompile( // TODO: We do go through different code path, should investigate whether this // has an impact on generated binary. const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN"); +#ifndef __HIP_PLATFORM_HCC__ if (prefix_env) { FUSER_PERF_SCOPE("load CUBIN"); @@ -403,6 +408,12 @@ NvrtcFunction nvrtcCompile( options.data(), option_vals.data())); } +#else + // load ptx directly + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData( + &(compiled_kernel_.module), ptx.data())); + +#endif AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleGetFunction( &(compiled_kernel_.function), compiled_kernel_.module, diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h index d30eb3fcda52..a601a956c175 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h +++ b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h @@ -12,7 +12,7 @@ typedef long long int int64_t; template struct Tensor { - T& operator[](int64_t ind) { + __device__ T& operator[](int64_t ind) { return data[ind]; }; @@ -25,7 +25,7 @@ struct Tensor { // They will be an error as well since zero-length arrays are not allowed. template struct Tensor { - T& operator[](int64_t) { + __device__ T& operator[](int64_t) { return *data; }; @@ -34,6 +34,9 @@ struct Tensor { )"; // Code support for FP16 __half type and intrinsics +#ifdef __HIP_PLATFORM_HCC__ +static auto code_fp16_support = R"()"; +#else static auto code_fp16_support = R"( #define __HALF_TO_US(var) *(reinterpret_cast(&(var))) #define __HALF_TO_CUS(var) *(reinterpret_cast(&(var))) @@ -55,7 +58,7 @@ __device__ float __half2float(const __half h) { return val; } )"; - +#endif // struct and code for functions that need random number generation static auto code_random_number_gen = R"( class Philox { @@ -184,6 +187,9 @@ __device__ float randLike(Philox rnd) { }; )"; +// Note: We agressively template functions taking dim3 in the functions below +// because ROCM uses different types for the various dim3 and maps them +// directly to intrinsics, but they're dim3 when used after modification. /* * EXAMPLE USAGE: * blockReduceSum @@ -196,14 +202,14 @@ static auto code_template_block_reduction = R"( // participate, otherwise it is the number of threads. We could start with warp // reductions, then reduce the warps, this could save some shared memory, but // may actually be slower. 
-template +template __inline__ __device__ void blockReduce( T& out, const T inp_val, Func reduction_op, - const dim3& thread_idx, - const dim3& block_dim, + const _dim3ti& thread_idx, + const _dim3bd& block_dim, T* shared_mem, bool read_write_pred, T init_val) { @@ -324,49 +330,47 @@ static auto code_template_grid_reduction = R"( namespace reduction { // Utility functions -__host__ __device__ __forceinline__ size_t size(const dim3& d) { +template +__host__ __device__ __forceinline__ size_t size(const _dim3& d) { return (size_t)d.x * (size_t)d.y * (size_t)d.z; } -__host__ __device__ __forceinline__ int isize(const dim3& d) { - return d.x * d.y * d.z; -} +#define isize(d) d.x * d.y * d.z -__host__ __device__ __forceinline__ size_t offset(const dim3& pos, const dim3& dim) { +template +__host__ __device__ __forceinline__ size_t offset(const _dim3pos& pos, const _dim3dim& dim) { return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x + (size_t)pos.z * (size_t)dim.x * (size_t)dim.y; } -__host__ __device__ __forceinline__ size_t ioffset(const dim3& pos, const dim3& dim) { - return pos.x + pos.y * dim.x + pos.z * dim.x * dim.y; -} +#define ioffset(pos, dim) pos.x + pos.y * dim.x + pos.z * dim.x * dim.y // Returns dim3 of each reduction segment. -template -__host__ __device__ dim3 dimension_of_reduction_segment(const dim3& grid_dim) { +template +__host__ __device__ dim3 dimension_of_reduction_segment(const _dim3& grid_dim) { return dim3{X_BLOCK ? grid_dim.x : 1, Y_BLOCK ? grid_dim.y : 1, Z_BLOCK ? grid_dim.z : 1}; } // Returns the number of blocks in each reduction segment. -template -__host__ __device__ size_t size_of_reduction_segment(const dim3& grid_dim) { +template +__host__ __device__ size_t size_of_reduction_segment(const _dim3& grid_dim) { return size(dimension_of_reduction_segment(grid_dim)); } // Returns the total number of reduction segments. -template -__host__ __device__ size_t number_of_reduction_segments(const dim3& grid_dim) { +template +__host__ __device__ size_t number_of_reduction_segments(const _dim3& grid_dim) { return (X_BLOCK ? 1: grid_dim.x) * (Y_BLOCK ? 1 : grid_dim.y) * (Z_BLOCK ? 1 : grid_dim.z); } // Returns the 1-D index of the segment of thread block of block_idx. -template -__host__ __device__ size_t index_of_reduction_segment(const dim3& block_idx, - const dim3& grid_dim) { +template +__host__ __device__ size_t index_of_reduction_segment(const _dim3bi& block_idx, + const _dim3gd& grid_dim) { size_t seg_idx = 0; if (!Z_BLOCK) seg_idx += block_idx.z; @@ -378,9 +382,9 @@ __host__ __device__ size_t index_of_reduction_segment(const dim3& block_idx, } // Returns the offset of thread block in its reduction segment. -template -__host__ __device__ size_t offset_in_reduction_segment(const dim3& block_idx, - const dim3& grid_dim) { +template +__host__ __device__ size_t offset_in_reduction_segment(const _dim3bi& block_idx, + const _dim3gd& grid_dim) { size_t offset = 0; if (Z_BLOCK) offset = offset * grid_dim.z + block_idx.z; @@ -392,23 +396,24 @@ __host__ __device__ size_t offset_in_reduction_segment(const dim3& block_idx, } // Returns dim3 of each reduction block. -template -__host__ __device__ dim3 dimension_of_reduction_block(const dim3& block_dim) { +template +__host__ __device__ dim3 dimension_of_reduction_block(const _dim3& block_dim) { return dim3{X_THREAD ? block_dim.x : 1, Y_THREAD ? block_dim.y : 1, Z_THREAD ? block_dim.z : 1}; } // Returns the number of threads of each reduction block. 
-template -__host__ __device__ int size_of_reduction_block(const dim3& block_dim) { - return isize(dimension_of_reduction_block(block_dim)); +template +__host__ __device__ int size_of_reduction_block(const _dim3& block_dim) { + auto tmp_dim = dimension_of_reduction_block(block_dim); + return isize(tmp_dim); } // Returns the linear offset of a thread in a reduction block. -template -__host__ __device__ int offset_in_reduction_block(const dim3& thread_idx, - const dim3& block_dim) { +template +__host__ __device__ int offset_in_reduction_block(const _dim3ti& thread_idx, + const _dim3bd& block_dim) { int offset = 0; if (Z_THREAD) offset += thread_idx.z; diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 9035a81b7490..167173d816f8 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -87,8 +87,9 @@ struct PeepholeOptimizeImpl { "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", /*const_inputs=*/attr::size)) { // x.expand(x.size()) == x - if (auto input_type = - node->namedInput(attr::self)->type()->cast()) { + auto input_type = + node->namedInput(attr::self)->type()->cast(); + if (input_type && shape_peepholes_) { auto expanded_sizes = node->get>(attr::size); auto input_type_sizes = input_type->sizes().concrete_sizes(); if (expanded_sizes.has_value() && input_type_sizes && @@ -110,8 +111,9 @@ struct PeepholeOptimizeImpl { input_node->input()->debugName()); node->output()->replaceAllUsesWith(input_node->input()); } - } else if (node->matches( - "aten::type_as(Tensor self, Tensor other) -> Tensor")) { + } else if ( + node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor") && + shape_peepholes_) { // x.type_as(y) == x iff x.type() == y.type() auto self_type = node->input(0)->type()->expect(); auto other_type = node->input(1)->type()->expect(); diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index fd49cf672e00..f42d44e8d3b1 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -330,6 +330,8 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { // Create and insert quantized embedding op. Value* none = g->insertConstant(IValue()); Value* zero = g->insertConstant(IValue(0)); + bool pruned_wt = false; + auto pruned_const = g->insertConstant(pruned_wt); if (is_aten_op) { TORCH_CHECK( @@ -340,6 +342,10 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { for (auto i = 1; i < inputs_size - 1; ++i) { qembedding_bag_inputs.push_back(embedding_bag_inputs[i]); } + // The sparse field in the float operator denotes sparse gradients. + // For inference this stands for pruned weights. We currently don't support + // pruning in graph mode API so we set the field to 0 for inference. 
+ qembedding_bag_inputs[5] = pruned_const; } else { TORCH_CHECK( inputs_size == 11, @@ -348,8 +354,8 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets qembedding_bag_inputs.push_back( embedding_bag_inputs[6]); // scale_grad_by_freq - qembedding_bag_inputs.push_back(zero); // zero - qembedding_bag_inputs.push_back(embedding_bag_inputs[8]); // sparse + qembedding_bag_inputs.push_back(zero); // mode + qembedding_bag_inputs.push_back(pruned_const); // pruned_weights qembedding_bag_inputs.push_back( embedding_bag_inputs[9]); // per_sample_weights } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index cf9a93b7f521..a22fbd353dec 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1208,6 +1208,7 @@ class ShapePropagator { "aten::max(Tensor self) -> Tensor", "aten::min(Tensor self) -> Tensor", "aten::median(Tensor self) -> Tensor", + "aten::nanmedian(Tensor self) -> Tensor", "aten::norm(Tensor self, Scalar p) -> Tensor", "aten::std(Tensor self, bool unbiased) -> Tensor", "aten::trace(Tensor self) -> Tensor", @@ -1354,6 +1355,7 @@ class ShapePropagator { "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", + "aten::nanmedian(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", }, [](Node* node) -> type_vec_t { diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index fd70f2963b8b..793e7557bbf1 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -86,10 +86,11 @@ bool isSupported(Node* node) { "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", + "aten::lgamma(Tensor self) -> Tensor", "aten::log10(Tensor self) -> Tensor", "aten::log(Tensor self) -> Tensor", "aten::log2(Tensor self) -> Tensor", - // TODO: log1p + "aten::log1p(Tensor self) -> Tensor", "aten::exp(Tensor self) -> Tensor", "aten::erf(Tensor self) -> Tensor", "aten::erfc(Tensor self) -> Tensor", @@ -150,6 +151,7 @@ bool isSupported(Node* node) { static const OperatorSet supported_reduction_set{ "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + "aten::softmax.int(Tensor self, int dim , ScalarType? 
dtype=None) -> Tensor", }; // clang-format on diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 1c6a5c968ce3..7b25517ce9ea 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -517,6 +517,13 @@ void initJITBindings(PyObject* module) { getNumProfiledRuns() = num; return old_num; }) + .def( + "_jit_get_num_profiled_runs", + [] { + // pybind can't automatically bind to atomic size_t + size_t num_runs = getNumProfiledRuns(); + return num_runs; + }) .def( "_jit_set_bailout_depth", [](size_t depth) { diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index b2049cb4362d..74e0e75362a6 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1456,6 +1456,8 @@ void initJitScriptBindings(PyObject* module) { m.def( "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining); + m.def("_debug_set_fusion_group_inlining", debugSetFusionGroupInlining); + m.def("_debug_get_fusion_group_inlining", getFusionGroupInlining); m.def("_propagate_shapes", _propagate_shapes); m.def( "_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes); diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index fbb515cea104..e19f801f58c0 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -85,6 +85,17 @@ bool getAutodiffSubgraphInlining() { return autodiff_subgraph_inlining; } +// for debugging it is helpful to be able to force fusion groups +// to be created +static std::atomic fusion_group_inlining(true); +void debugSetFusionGroupInlining(bool state) { + fusion_group_inlining = state; +} + +bool getFusionGroupInlining() { + return fusion_group_inlining; +} + thread_local std::weak_ptr last_executed_optimized_graph; std::shared_ptr lastExecutedOptimizedGraph() { return last_executed_optimized_graph.lock(); diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index 6fee30834f1e..c0ab14badc33 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -93,6 +93,9 @@ TORCH_API Node* replaceBlockWithFallbackGraph( // regardless of whether sizes have been specialized or not. 
TORCH_API void runRequiredPasses(const std::shared_ptr& g); +TORCH_API void debugSetFusionGroupInlining(bool state); +TORCH_API bool getFusionGroupInlining(); + TORCH_API void debugSetAutodiffSubgraphInlining(bool state); TORCH_API std::shared_ptr lastExecutedOptimizedGraph(); diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h index eed7e1f57f1f..12c4617b955d 100644 --- a/torch/csrc/jit/runtime/graph_executor_impl.h +++ b/torch/csrc/jit/runtime/graph_executor_impl.h @@ -38,6 +38,9 @@ void runNondiffOptimization( void debugSetAutodiffSubgraphInlining(bool state); bool getAutodiffSubgraphInlining(); +void debugSetFusionGroupInlining(bool state); +bool getFusionGroupInlining(); + // Tunable parameters for deciding when to create/keep subgraphs of // differentiable code const size_t autodiffSubgraphNodeThreshold = 2; diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index d8af0756e2e6..06a2a15199bd 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -253,7 +253,7 @@ void runDiffGraphPasses(std::shared_ptr& graph) { BatchMM(graph); GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph); - FuseTensorExprs(graph); + FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1); GRAPH_DEBUG( "After Fusion, before RemoveTensorTypeSpecializations\n", *graph); @@ -312,7 +312,7 @@ void runNoGradOptimizations(std::shared_ptr& graph) { BatchMM(graph); GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph); - FuseTensorExprs(graph); + FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1); GRAPH_DEBUG( "After Fusion, before RemoveTensorTypeSpecializations\n", *graph); diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index c8f62870f789..28a4136ba829 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -431,10 +431,7 @@ RegisterOperators reg({ aliasAnalysisConservative()), Operator( "aten::set_grad_enabled(bool val) -> ()", - [](Stack* stack) { - torch::GradMode::set_enabled(pop(stack).toBool()); - push(stack, IValue()); - }, + [](Stack* stack) { torch::GradMode::set_enabled(pop(stack).toBool()); }, aliasAnalysisConservative()), }); } // namespace diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 36fe7a1225d6..9604c384c548 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -28,7 +28,7 @@ std::shared_ptr PrepareForStaticRuntime( // remove unused input 0 from graph if (g->inputs().at(0)->type()->is_module()) { - TORCH_INTERNAL_ASSERT(!g->inputs().at(0)->hasUses()); + TORCH_CHECK(!g->inputs().at(0)->hasUses()); g->eraseInput(0); } @@ -54,21 +54,58 @@ StaticRuntime::StaticRuntime( std::shared_ptr g, c10::optional m) : graph_(g) { + // assign register to Value* + std::unordered_map value_to_reg; + for (Value* input : g->inputs()) { + TORCH_CHECK(value_to_reg.count(input) == 0); + size_t index = value_to_reg.size(); + value_to_reg[input] = index; + input_regs_.push_back(index); + } + for (Node* node : graph_->nodes()) { + for (Value* input : node->inputs()) { + TORCH_CHECK(value_to_reg.count(input) > 0); + } + for (Value* output : node->outputs()) { + TORCH_CHECK( + value_to_reg.count(output) == 0, "the graph needs to be in SSA form"); + size_t index = value_to_reg.size(); + 
value_to_reg[output] = index; + } + } + + TORCH_CHECK(g->outputs().size() > 0); + for (Value* output : g->outputs()) { + TORCH_CHECK(value_to_reg.count(output) > 0); + output_regs_.push_back(value_to_reg[output]); + } + + // initialize registers + reg_.resize(value_to_reg.size()); + // fill workspace_ with constants for (Node* node : graph_->nodes()) { if (node->kind() == prim::Constant) { - CHECK(node->output()->type()->kind() != FunctionType::Kind); - workspace_[node->output()] = toIValue(node->output()).value(); + TORCH_CHECK(node->output()->type()->kind() != FunctionType::Kind); + reg_[value_to_reg[node->output()]] = toIValue(node->output()).value(); } else { - nodes_.emplace_back(node); + std::vector input_regs, output_regs; + for (Value* input : node->inputs()) { + input_regs.push_back(value_to_reg[input]); + } + for (Value* output : node->outputs()) { + output_regs.push_back(value_to_reg[output]); + } + nodes_.emplace_back(node, std::move(input_regs), std::move(output_regs)); } } + if (m) { Method method = m->get_method("forward"); const c10::FunctionSchema& schema = method.function().getSchema(); // remove "self" from function schema - TORCH_INTERNAL_ASSERT( + TORCH_CHECK( schema.arguments().size() >= 1 && schema.arguments()[0].name() == "self"); std::vector args( @@ -81,16 +118,16 @@ StaticRuntime::StaticRuntime( std::vector StaticRuntime::run( const std::vector& inps) const { for (size_t i = 0; i < inps.size(); i++) { - workspace_[graph_->inputs()[i]] = inps[i]; + Input(i) = inps[i]; } for (const auto& n : nodes_) { - n.run(workspace_); + n.run(reg_); } std::vector out; - for (Value* output : graph_->outputs()) { - const IValue& v = workspace_[output]; + for (size_t i = 0; i < graph_->outputs().size(); i++) { + const IValue& v = Output(i); if (v.isTuple()) { auto t = v.toTuple(); for (const auto& el : t->elements()) { @@ -109,21 +146,21 @@ c10::IValue StaticRuntime::run( std::vector stack(args); if (!kwargs.empty()) { // This is not ideal - TORCH_INTERNAL_ASSERT( + TORCH_CHECK( schema_ != nullptr, "Schema is not available. Consider creating the Static Runtime " "with StaticRuntime(const torch::jit::Module& m) instead."); schema_->checkAndNormalizeInputs(stack, kwargs); } for (size_t i = 0; i < stack.size(); i++) { - workspace_[graph_->inputs()[i]] = stack[i]; + Input(i) = stack[i]; } for (const auto& n : nodes_) { - n.run(workspace_); + n.run(reg_); } - return workspace_[graph_->outputs().at(0)]; + return Output(0); } void StaticRuntime::benchmark( @@ -200,14 +237,14 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( std::vector stack(args); if (!kwargs.empty()) { // This is not ideal - TORCH_INTERNAL_ASSERT( + TORCH_CHECK( schema_ != nullptr, "Schema is not available. 
Consider creating the Static Runtime " "with StaticRuntime(const torch::jit::Module& m) instead."); schema_->checkAndNormalizeInputs(stack, kwargs); } for (size_t i = 0; i < stack.size(); i++) { - workspace_[graph_->inputs()[i]] = stack[i]; + Input(i) = stack[i]; } results.setup_time = timer.MilliSeconds(); @@ -220,7 +257,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( for (int i = 0; i < main_runs; i++) { for (size_t j = 0; j < nodes_.size(); j++) { timer.Start(); - nodes_[j].run(workspace_); + nodes_[j].run(reg_); float millis = timer.MilliSeconds(); results.time_per_node[j] += millis; } @@ -242,12 +279,18 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( return results; } -ProcessedNode::ProcessedNode(Node* node) : node_(node) { +ProcessedNode::ProcessedNode( + Node* node, + std::vector&& input_regs, + std::vector&& output_regs) + : node_(node), + input_regs_(std::move(input_regs)), + output_regs_(std::move(output_regs)) { if (node->kind() != prim::ListConstruct && node->kind() != prim::TupleConstruct && node->kind() != prim::ListUnpack) { const Operator& op = node->getOperator(); - CHECK(op.hasOperation()); + TORCH_CHECK(op.hasOperation()); op_ = op.getOperation(node); } if (canRunOutOfPlace(node)) { @@ -255,19 +298,13 @@ ProcessedNode::ProcessedNode(Node* node) : node_(node) { } } -void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const { +void ProcessedNode::run(std::vector& reg) const { if (!fn_) { std::vector stack; const size_t size = node_->inputs().size(); stack.reserve(size); for (size_t i = 0; i < size; i++) { - Value* v = node_->inputs()[i]; - auto f = workspace.find(v); - TORCH_CHECK( - f != workspace.end(), - "Workspace does not contain Value ", - v->debugName()); - stack.emplace_back(f->second); + stack.emplace_back(Input(i, reg)); } if (op_) { op_->operator()(&stack); @@ -297,10 +334,10 @@ void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const { } DCHECK_EQ(stack.size(), node_->outputs().size()); for (auto i = 0; i < node_->outputs().size(); i++) { - workspace[node_->outputs()[i]] = stack[i]; + Output(i, reg) = std::move(stack[i]); } } else { - fn_->operator()(workspace); + fn_->operator()(this, reg); } } diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 2703da4cf122..62bcd8d63099 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -7,10 +7,6 @@ #include #include -#ifdef FBCODE_CAFFE2 -#include -#endif - namespace torch { namespace jit { @@ -61,12 +57,6 @@ class TORCH_API StaticRuntime { const int warmup_runs, const int main_runs) const; -#ifdef FBCODE_CAFFE2 - using ConstantMap = folly::F14FastMap; -#else - using ConstantMap = std::unordered_map; -#endif - private: explicit StaticRuntime( std::shared_ptr g, // optimized graph @@ -77,8 +67,22 @@ class TORCH_API StaticRuntime { std::unique_ptr schema_{nullptr}; // Static runtime states - // Value table (including weights) - mutable ConstantMap workspace_; + // IValue table (including inputs, outputs, intermediates, and weights) + mutable std::vector reg_; + std::vector input_regs_; // inputs to the graph + std::vector output_regs_; // outputs of the graph + + // Input is readwrite + IValue& Input(size_t i) const { + DCHECK(i < input_regs_.size()); + return reg_[input_regs_[i]]; + } + + // Output is readonly. 
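To make the register-based refactor above easier to follow, here is a toy sketch in plain Python of the same idea (illustrative only, not the C++ `StaticRuntime` API): every value gets a fixed register index up front, and each processed node stores only the indices of its inputs and outputs, so execution reads and writes a flat list instead of hashing `Value*` keys in a workspace map.

```python
# Conceptual sketch of register-based execution (not PyTorch API).
class ProcessedNode:
    def __init__(self, fn, input_regs, output_regs):
        self.fn, self.input_regs, self.output_regs = fn, input_regs, output_regs

    def run(self, reg):
        outs = self.fn(*[reg[i] for i in self.input_regs])
        if len(self.output_regs) == 1:
            outs = (outs,)
        for idx, val in zip(self.output_regs, outs):
            reg[idx] = val

# Toy graph in SSA form: c = a + b; d = c * c   (registers: a=0, b=1, c=2, d=3)
nodes = [ProcessedNode(lambda x, y: x + y, [0, 1], [2]),
         ProcessedNode(lambda x: x * x, [2], [3])]
reg = [None] * 4
reg[0], reg[1] = 2, 3      # write the graph inputs into their registers
for n in nodes:
    n.run(reg)
print(reg[3])              # read the output register -> 25
```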
The writing process happens inside ProcessedNodes + const IValue& Output(size_t i) const { + DCHECK(i < output_regs_.size()); + return reg_[output_regs_[i]]; + } // The nodes we need to run std::vector nodes_; @@ -86,16 +90,36 @@ class TORCH_API StaticRuntime { class ProcessedNode { public: - ProcessedNode(Node* n); - void run(StaticRuntime::ConstantMap& workspace) const; + ProcessedNode( + Node* n, + std::vector&& input_regs, + std::vector&& output_regs); + void run(std::vector& reg) const; + Node* get_node() const { return node_; } + // Input is readonly + const IValue& Input(size_t i, std::vector& reg) const { + DCHECK(i < input_regs_.size()); + return reg[input_regs_[i]]; + } + + // Output is readwrite + IValue& Output(size_t i, std::vector& reg) const { + DCHECK(i < output_regs_.size()); + return reg[output_regs_[i]]; + } + private: Node* node_; c10::optional op_; - c10::optional> fn_; + c10::optional&)>> + fn_; + + std::vector input_regs_; + std::vector output_regs_; }; } // namespace jit diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index fe91920f3c11..8c4ac6e804ee 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -18,136 +18,103 @@ bool canRunOutOfPlace(Node* n) { return out_of_place_nodes.count(str) > 0; } -std::function getOutOfPlaceOperation( - Node* n) { +std::function&)> +getOutOfPlaceOperation(Node* n) { auto create_empty_from = [](const at::Tensor& t) { return at::empty({0}, t.options()); }; if (n->kind() == c10::Symbol::fromQualString("aten::add")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - auto in2 = n->inputs().at(2); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - auto in1_t = ws.at(in1).toTensor(); - auto in2_s = ws.at(in2).toScalar(); - if (!ws.count(out)) { - ws.emplace(out, create_empty_from(in0_t)); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + auto in1_t = p_node->Input(1, reg).toTensor(); + auto in2_s = p_node->Input(2, reg).toScalar(); + if (p_node->Output(0, reg).isNone()) { + p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = ws.at(out).toTensor(); + auto out_t = p_node->Output(0, reg).toTensor(); at::native::add_out(out_t, in0_t, in1_t, in2_s); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::mul")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - auto in1_t = ws.at(in1).toTensor(); - if (!ws.count(out)) { - ws.emplace(out, create_empty_from(in0_t)); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + auto in1_t = p_node->Input(1, reg).toTensor(); + if (p_node->Output(0, reg).isNone()) { + p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = ws.at(out).toTensor(); + auto out_t = p_node->Output(0, reg).toTensor(); at::native::mul_out(out_t, in0_t, in1_t); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::addmm")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - auto in2 = n->inputs().at(2); - auto in3 = n->inputs().at(3); - auto in4 = n->inputs().at(4); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - auto in1_t = ws.at(in1).toTensor(); - auto in2_t = 
ws.at(in2).toTensor(); - auto in3_s = ws.at(in3).toScalar(); - auto in4_s = ws.at(in3).toScalar(); - if (!ws.count(out)) { - ws.emplace(out, create_empty_from(in0_t)); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + auto in1_t = p_node->Input(1, reg).toTensor(); + auto in2_t = p_node->Input(2, reg).toTensor(); + auto in3_s = p_node->Input(3, reg).toScalar(); + auto in4_s = p_node->Input(4, reg).toScalar(); + if (p_node->Output(0, reg).isNone()) { + p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = ws.at(out).toTensor(); + auto out_t = p_node->Output(0, reg).toTensor(); at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::clamp")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - auto in2 = n->inputs().at(2); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - auto in1_s = ws.at(in1).toScalar(); - auto in2_s = ws.at(in2).toScalar(); - if (!ws.count(out)) { - ws.emplace(out, create_empty_from(in0_t)); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + auto in1_s = p_node->Input(1, reg).toScalar(); + auto in2_s = p_node->Input(2, reg).toScalar(); + if (p_node->Output(0, reg).isNone()) { + p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = ws.at(out).toTensor(); + auto out_t = p_node->Output(0, reg).toTensor(); at::native::clamp_out(out_t, in0_t, in1_s, in2_s); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::bmm")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - auto in1_t = ws.at(in1).toTensor(); - if (!ws.count(out)) { - ws.emplace(out, create_empty_from(in0_t)); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + auto in1_t = p_node->Input(1, reg).toTensor(); + if (p_node->Output(0, reg).isNone()) { + p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = ws.at(out).toTensor(); + auto out_t = p_node->Output(0, reg).toTensor(); at::native::bmm_out_cpu(out_t, in0_t, in1_t); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::cat")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_tl = ws.at(in0).toTensorVector(); - auto in1_i = ws.at(in1).toInt(); - if (!ws.count(out)) { - ws.emplace(out, create_empty_from(in0_tl[0])); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_tl = p_node->Input(0, reg).toTensorVector(); + auto in1_i = p_node->Input(1, reg).toInt(); + if (p_node->Output(0, reg).isNone()) { + p_node->Output(0, reg) = create_empty_from(in0_tl[0]); } - auto out_t = ws.at(out).toTensor(); + auto out_t = p_node->Output(0, reg).toTensor(); at::native::_cat_out_cpu(out_t, in0_tl, in1_i); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::sigmoid")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - if (!ws.count(out)) { - ws.emplace(out, create_empty_from(in0_t)); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + if (p_node->Output(0, 
reg).isNone()) { + p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = ws.at(out).toTensor(); + auto out_t = p_node->Output(0, reg).toTensor(); at::native::sigmoid_out(out_t, in0_t); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::transpose")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - auto in2 = n->inputs().at(2); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - auto in1_i = ws.at(in1).toInt(); - auto in2_i = ws.at(in2).toInt(); - ws[out] = at::native::transpose(in0_t, in1_i, in2_i); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + auto in1_i = p_node->Input(1, reg).toInt(); + auto in2_i = p_node->Input(2, reg).toInt(); + p_node->Output(0, reg) = at::native::transpose(in0_t, in1_i, in2_i); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) { - auto out = n->outputs().at(0); - auto in0 = n->inputs().at(0); - auto in1 = n->inputs().at(1); - auto in2 = n->inputs().at(2); - return [=](StaticRuntime::ConstantMap& ws) { - auto in0_t = ws.at(in0).toTensor(); - auto in1_i = ws.at(in1).toInt(); - auto in2_i = ws.at(in2).toInt(); - ws[out] = at::native::flatten(in0_t, in1_i, in2_i); + return [=](const ProcessedNode* p_node, std::vector& reg) { + auto in0_t = p_node->Input(0, reg).toTensor(); + auto in1_i = p_node->Input(1, reg).toInt(); + auto in2_i = p_node->Input(2, reg).toInt(); + p_node->Output(0, reg) = at::native::flatten(in0_t, in1_i, in2_i); }; } - return [](StaticRuntime::ConstantMap&) { TORCH_CHECK(0); }; + return [](const ProcessedNode*, std::vector&) { TORCH_CHECK(0); }; } } // namespace jit diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h index e00416d786d3..e75673ca2431 100644 --- a/torch/csrc/jit/runtime/static/ops.h +++ b/torch/csrc/jit/runtime/static/ops.h @@ -7,8 +7,8 @@ namespace torch { namespace jit { bool canRunOutOfPlace(Node* n); -std::function getOutOfPlaceOperation( - Node* n); +std::function&)> +getOutOfPlaceOperation(Node* n); #define SUPPORTED_OPS(F) \ F(aten::__getitem__) \ diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.cpp b/torch/csrc/jit/tensorexpr/bounds_inference.cpp index 9c065179b9d9..272c1459c5e2 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_inference.cpp @@ -15,6 +15,7 @@ class BoundsInference : public IRVisitor { void visit(const FunctionCall* v) override; void visit(const Load* v) override; void visit(const Store* v) override; + void visit(const ReduceOp* v) override; void visit(const For* v) override; void visit(const Block* v) override; @@ -39,6 +40,12 @@ void BoundsInference::visit(const Store* v) { IRVisitor::visit(v); } +void BoundsInference::visit(const ReduceOp* v) { + accesses_[v->accumulator()].push_back( + {kLoad, v->output_args(), v->output_args()}); + IRVisitor::visit(v); +} + void BoundsInference::visit(const For* v) { v->body()->accept(this); for (auto& pair : accesses_) { diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 54eb974beb12..838408402b1f 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -206,6 +206,7 @@ std::vector TensorExprKernel::inferSizesForValue( case aten::relu: case aten::log: case aten::log10: + case aten::log1p: case aten::log2: case aten::exp: case aten::expm1: @@ -337,6 +338,11 @@ std::vector 
TensorExprKernel::inferSizesForValue( shape[dim] = concat_size; return shape; } + + case aten::softmax: + // Output of Softmax has the same shape as input 0. + return sizesForValue(v->node()->input(0)); + case aten::slice: throw std::runtime_error( "Shape info is not implemented for this kind of node"); @@ -902,6 +908,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { "aten_log10", v, [](const ExprHandle& a) { return log10(a); }); } break; + case aten::log1p: { + return computeOneOperand( + "aten_log1p", v, [](const ExprHandle& a) { return log1p(a); }); + } break; + case aten::log2: { return computeOneOperand( "aten_log2", v, [](const ExprHandle& a) { return log2(a); }); @@ -1232,6 +1243,10 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { return computeSum(v); } + case aten::softmax: { + return computeSoftmax(v); + } + default: { throw std::runtime_error("Unhandled node kind"); } @@ -1554,6 +1569,110 @@ Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) { reduction_info.reductionDims); } +Tensor* TensorExprKernel::computeSoftmax(const torch::jit::Value* v) { + // Softmax is computed as follows: + // softmax(vi) = exp(vi) / sum(exp(vi)) + // + // In order to avoid overflow issues due to exp of a large number, we + // subtract the max of that dim before computing exp. + // softmax(vi) = exp(vi - max(vi)) / sum(exp(vi - max(vi))) + // + // This is implemented as 4 loopnests: + // - First loop computes the max over the softmax dim. + // - Second loop computes exp for every element in v after subtracting + // the max of the softmax dim it belongs to. + // - Third loop computes the sum over the softmax dim. + // - Final loop computes softmax for every element in v. + + TORCH_INTERNAL_ASSERT(v->node()->inputs().size() == 3); + auto output_dims = dimsFromSizes(sizesForValue(v)); + + // We do not handle None for dims (input 1) because that is supposed to + // be deprecated. + TORCH_INTERNAL_ASSERT(v->node()->input(1)->node()->kind() == prim::Constant); + size_t softmax_dim = v->node()->input(1)->node()->i(attr::value); + TORCH_INTERNAL_ASSERT(softmax_dim < output_dims.size()); + + std::vector non_softmax_dims; + for (size_t i = 0; i < output_dims.size(); ++i) { + if (i != softmax_dim) { + non_softmax_dims.push_back(output_dims[i]); + } + } + + // Softmax implementation includes two reductions, one to find the max and + // the other to calculate the sum along the softmax dim. These reductions + // will have the softmax dimension as the inner most loop. So, the innermost + // index in the indices will refer to the softmax dimension. + + // Update the indices by moving the softmax dimension index to the + // appropriate position. + auto move_softmax_dim_index_to_pos = [&](const ParameterList& indices) { + std::vector new_indices; + for (auto ind : indices) { + new_indices.push_back(ind); + } + for (size_t i = softmax_dim; i < indices.size() - 1; ++i) { + new_indices[i + 1] = indices[i]; + } + new_indices[softmax_dim] = indices[indices.size() - 1]; + return new_indices; + }; + + // Remove the index corresponding to the softmax dimension. 
+ auto remove_softmax_dim_index = [&](const ParameterList& indices) { + std::vector new_indices; + for (size_t i = 0; i < indices.size(); ++i) { + if (i != softmax_dim) { + new_indices.push_back(indices[i]); + } + } + return new_indices; + }; + + auto convert_indices_to_expr_handle = [&](const ParameterList& indices) { + std::vector new_indices(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + new_indices[i] = indices[i]; + } + return new_indices; + }; + + c10::optional dtype = ToDtype(ScalarType::None); + auto maybe_dtype = v->node()->get(attr::dtype); + if (maybe_dtype && !maybe_dtype->isNone()) { + dtype = ToDtype(static_cast(maybe_dtype->toInt())); + } + + auto max = Reduce( + "aten_softmax_max", + non_softmax_dims, + Maximum(dtype.value()), + [&](ParameterList& indices) { + return tensorOrConstant( + v->node()->inputs()[0], move_softmax_dim_index_to_pos(indices)); + }, + {output_dims[softmax_dim]}); + auto e = + Compute("aten_softmax_exp", output_dims, [&](ParameterList& indices) { + auto inp = tensorOrConstant( + v->node()->inputs()[0], convert_indices_to_expr_handle(indices)); + return exp(inp - max->call(remove_softmax_dim_index(indices))); + }); + auto sum = Reduce( + "aten_softmax_sum", + non_softmax_dims, + Sum(), + [&](ParameterList& indices) { + return e->call(move_softmax_dim_index_to_pos(indices)); + }, + {output_dims[softmax_dim]}); + auto res = Compute("aten_softmax", output_dims, [&](ParameterList& indices) { + return e->call(indices) / sum->call(remove_softmax_dim_index(indices)); + }); + return res; +} + TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo( const torch::jit::Node* node) { std::vector axes; diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index bcc5682f68a5..028b18112ab2 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -121,6 +121,8 @@ class TORCH_API TensorExprKernel { Tensor* computeSum(const torch::jit::Value* v); + Tensor* computeSoftmax(const torch::jit::Value* v); + Tensor* computeValue(const torch::jit::Value* v); void flattenTensors(BackendType backendType); diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 5c093707dc95..5d3df6ed2f67 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1428,6 +1428,283 @@ static std::vector getOuterLoopIndexes(Stmt* s) { return res; } +class CacheReplacer : public IRMutator { + public: + CacheReplacer( + const Buf* buffer, + const Buf* cache, + std::vector& offsets) + : buf_(buffer), cache_(cache), offsets_(offsets) {} + + private: + const Expr* mutate(const FunctionCall* v) override { + const Buf* buf = v->tensor()->buf(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + // for reductions the size of tensor->args() is not equal to the size of the + // output buffer, but they should be ordered so that the output args are at + // the beginning even if the loops are reordered later. + // Map indices to call-parameters. 
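The `computeSoftmax` lowering above follows the max-subtraction scheme spelled out in its comment (max, exp, sum, divide, realized as four loop nests). A minimal eager-mode sketch of the same arithmetic, checked against `F.softmax`:

```python
import torch
import torch.nn.functional as F

def stable_softmax(x, dim):
    m = x.max(dim=dim, keepdim=True).values   # 1) max over the softmax dim
    e = (x - m).exp()                          # 2) exp after subtracting the max
    s = e.sum(dim=dim, keepdim=True)           # 3) sum over the softmax dim
    return e / s                               # 4) normalize

x = torch.randn(4, 8)
assert torch.allclose(stable_softmax(x, 1), F.softmax(x, dim=1))
```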
+ std::vector newIndices; + for (size_t i = 0; i < offsets_.size(); ++i) { + const Expr* index = v->param(i)->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new Load(cache_, newIndices, new IntImm(1)); + } + + const Expr* mutate(const Load* v) override { + const Buf* buf = v->buf(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + // Map indices to call-parameters. + std::vector newIndices; + TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); + for (size_t i = 0; i < v->indices().size(); ++i) { + const Expr* index = v->indices()[i]->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new Load(cache_, newIndices, v->mask()); + } + + Stmt* mutate(const Store* v) override { + const Buf* buf = v->buf(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + const Expr* newValue = v->value()->accept_mutator(this); + + // Map indices to call-parameters. + std::vector newIndices; + TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); + for (size_t i = 0; i < v->indices().size(); ++i) { + const Expr* index = v->indices()[i]->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new Store(cache_, newIndices, newValue, v->mask()); + } + + const Expr* mutate(const ReduceOp* v) override { + const Buf* buf = v->accumulator(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + const Expr* newBody = v->body().node()->accept_mutator(this); + + // Map indices to call-parameters. + std::vector newIndices; + TORCH_INTERNAL_ASSERT(offsets_.size() == v->output_args().size()); + for (size_t i = 0; i < v->output_args().size(); ++i) { + const Expr* index = v->output_args()[i]->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new ReduceOp( + cache_, + ExprHandle(newBody), + v->interaction(), + newIndices, + v->reduce_args()); + } + + const Buf* buf_; + const Buf* cache_; + std::vector& offsets_; +}; + +LoopNest::AccessResult LoopNest::cacheAccesses( + const Buf* producer, + const std::string& name, + Stmt* consumer) { + ReduceOp* reduceOp{nullptr}; + auto reductions = NodeFinder::find(consumer); + for (auto* ro : reductions) { + if (ro->accumulator() != producer) { + continue; + } + + if (reduceOp) { + throw std::runtime_error( + "can only cache accesses used by at most a single reduceOp"); + return {nullptr, nullptr}; + } + + reduceOp = ro; + } + + auto consumer_bounds_info = inferBounds(consumer); + auto bounds_it = consumer_bounds_info.find(producer); + if (bounds_it == consumer_bounds_info.end()) { + throw std::runtime_error("consumer does not use the Tensor produced"); + return {nullptr, nullptr}; + } + + std::vector starts; + std::vector stops; + + bool hasReads = false; + bool hasWrites = false; + // Find the safe size of the temprorary buffer by determining the outer + // extents of a union of all bounds. 
+ for (const TensorAccessBoundsInfo& p : bounds_it->second) { + hasReads |= p.kind == kLoad; + hasWrites |= p.kind == kStore; + + for (size_t i = 0; i < p.start.size(); i++) { + if (starts.size() <= i) { + starts.push_back(p.start[i]); + } else { + starts[i] = + IRSimplifier::simplify(new Min(starts[i], p.start[i], true)); + } + + if (stops.size() <= i) { + stops.push_back(p.stop[i]); + } else { + stops[i] = IRSimplifier::simplify(new Max(stops[i], p.stop[i], true)); + } + } + } + + std::vector var_names = {"i", "j", "k", "l", "m", "n", "o", "p"}; + std::vector tmp_dims; + std::vector new_loop_vars; + std::vector new_loop_vars_expr; + + // Determine the size of the cache, and create a loop var for each dimension. + for (size_t i = 0; i < starts.size(); ++i) { + const Expr* dim = IRSimplifier::simplify( + new Add(new Sub(stops[i], starts[i]), new IntImm(1))); + + tmp_dims.push_back(dim); + + new_loop_vars.push_back(new Var(var_names[i % var_names.size()], kInt)); + new_loop_vars_expr.push_back(new_loop_vars[i]); + } + + // Create the var. + Buf* tmp_buf = new Buf(new Var(name, kHandle), tmp_dims, producer->dtype()); + + // determine the offsets for calls into the cache based off the loop start of + // each axis. + std::vector tmp_params; + for (size_t i = 0; i < new_loop_vars.size(); ++i) { + tmp_params.push_back(new Add(new_loop_vars[i], starts[i])); + } + + // Replace acceses to the producer in the consumer with the cache. + CacheReplacer replacer(producer, tmp_buf, starts); + Stmt* new_consumer = + IRSimplifier::simplify(consumer->accept_mutator(&replacer)); + + intermediate_bufs_.insert(tmp_buf); + + // replace the old consumer with the replaced consumer. + Block* consumer_block = nullptr; + // if the consumer is a block, we should mutate it in place. + if ((consumer_block = dynamic_cast(consumer))) { + consumer_block->clear(); + consumer_block->append_stmt(new_consumer); + } else { + consumer_block = dynamic_cast(consumer->get_parent()); + assert(consumer_block); + consumer_block->replace_stmt(consumer, new_consumer); + } + + // If there's a reduction we can't just write the result straight back to the + // original buffer, since after parallelism the writes will race. Instead we + // need to create a new ReduceOp. + if (reduceOp) { + // reduceOp means we had both loads and stores. + + // Init cache to 0. + Stmt* tmp_init = new Store( + tmp_buf, + new_loop_vars_expr, + getImmediateByType(tmp_buf->dtype(), 0), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_init = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_init); + } + + consumer_block->insert_stmt_before(tmp_init, new_consumer); + + // Reduce back to the original buffer: + Stmt* tmp_store = new Store( + producer, + tmp_params, + new ReduceOp( + producer, + ExprHandle(new Load(tmp_buf, new_loop_vars_expr, new IntImm(1))), + reduceOp->interaction(), + tmp_params, + {}), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_store = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + } + + consumer_block->insert_stmt_after(tmp_store, new_consumer); + + return std::make_pair(tmp_buf, new_consumer); + } + + if (hasReads) { + // Fill the cache with values from the consumer. 
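As a rough mental model of `cacheAccesses` (plain Python, not the `LoopNest` API): take the union of the bounds the consumer touches on the producer, stage that region in a temporary buffer sized to it, rewrite the consumer's indices by the region's start offset, and copy in before (and/or out after) the consumer. A 1-D toy version under those assumptions:

```python
# Toy 1-D sketch: the consumer reads producer[i] for i in [start, stop].
def cache_accesses(producer, start, stop, consume):
    cache = [0] * (stop - start + 1)        # size = extent of the accessed bounds
    for i in range(start, stop + 1):        # fill the cache from the producer
        cache[i - start] = producer[i]      # indices are shifted by `start`
    return consume(lambda i: cache[i - start])  # consumer now reads the cache

producer = list(range(100, 110))
total = cache_accesses(producer, 2, 5,
                       lambda read: sum(read(i) for i in range(2, 6)))
print(total)   # 102 + 103 + 104 + 105 = 414
```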
+ Stmt* tmp_store = new Store( + tmp_buf, + new_loop_vars_expr, + new Load(producer, tmp_params, new IntImm(1)), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_store = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + } + + consumer_block->insert_stmt_before(tmp_store, new_consumer); + } + + if (hasWrites) { + // sync the cache back to the producer buf. + Stmt* tmp_store = new Store( + producer, + tmp_params, + new Load(tmp_buf, new_loop_vars_expr, new IntImm(1)), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_store = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + } + + consumer_block->insert_stmt_after(tmp_store, new_consumer); + } + + return std::make_pair(tmp_buf, new_consumer); +} + /* * WHAT COMPUTE_AT DOES * ==================== diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 8eebf82b9886..5cdcb0ef5c91 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -78,11 +78,21 @@ class TORCH_API LoopNest { void setGPUBlockIndex(For* f, int idx); void setGPUThreadIndex(For* f, int idx); + using AccessResult = std::pair; + // Insert a cache for the consumer's usages of the buffer produced in + // consumer, and redirect reads and writes in the consumer to that cache. + // Returns a pair of the new cache buffer, and the new rewritten consumer. + AccessResult cacheAccesses( + const Buf* producer, + const std::string& name, + Stmt* consumer); + // Insert a temporary computation of statement S in the scope of loop AT. // S is assumed to be a Store or a Block containing a Store. Along with the // computation itself, this transformation inserts Alloc/Free statements for // the temporary buffer used in the computation. void computeAt(Stmt* s, For* at); + void rfactor( const Expr* f, const Var* reduction_var, diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index 7aec71f6b56d..abdaea147c00 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -163,6 +163,13 @@ class TORCH_API Block : public StmtNode { return stmts_; } + void clear() { + for (auto* s : stmts_) { + set_parent(s, nullptr); + } + stmts_.clear(); + } + explicit Block(const std::vector& stmts) { for (Stmt* s : stmts) { if (s->get_parent()) { diff --git a/torch/fft/__init__.py b/torch/fft/__init__.py index b3ffdb24b4a0..5ae82660eeaa 100644 --- a/torch/fft/__init__.py +++ b/torch/fft/__init__.py @@ -688,6 +688,7 @@ beginning of the tensor. 
This also works for multi-dimensional transforms: + >>> x = torch.fft.fftfreq(5, d=1/5) + 0.1 * torch.fft.fftfreq(5, d=1/5).unsqueeze(1) >>> x tensor([[ 0.0000, 1.0000, 2.0000, -2.0000, -1.0000], diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 5c362a42fcf5..83035666d7e9 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -4,6 +4,62 @@ namespace c10d { +std::string opTypeToString(OpType opType) { + switch (opType) { + case OpType::BROADCAST: + return "BROADCAST"; + case OpType::ALLREDUCE: + return "ALLREDUCE"; + case OpType::ALLREDUCE_COALESCED: + return "ALLREDUCE_COALESCED"; + case OpType::REDUCE: + return "REDUCE"; + case OpType::ALLGATHER: + return "ALLGATHER"; + case OpType::ALLGATHER_BASE: + return "ALLGATHER_BASE"; + case OpType::ALLGATHER_COALESCED: + return "ALLGATHER_COALESCED"; + case OpType::GATHER: + return "GATHER"; + case OpType::SCATTER: + return "SCATTER"; + case OpType::REDUCE_SCATTER: + return "REDUCE_SCATTER"; + case OpType::ALLTOALL_BASE: + return "ALLTOALL_BASE"; + case OpType::ALLTOALL: + return "ALLTOALL"; + case OpType::SEND: + return "SEND"; + case OpType::RECV: + return "RECV"; + case OpType::RECVANYSOURCE: + return "RECVANYSOURCE"; + case OpType::BARRIER: + return "BARRIER"; + case OpType::UNKNOWN: + return "UNKNOWN"; + default: + TORCH_INTERNAL_ASSERT("Unknown op type!"); + } + return "UNKNOWN"; +} + +bool isP2POp(OpType opType) { + return opType == OpType::SEND || opType == OpType::RECV || + opType == OpType::RECVANYSOURCE; +} + +ProcessGroup::Work::Work() : rank_(-1), opType_(OpType::UNKNOWN) {} + +ProcessGroup::Work::Work(int rank, OpType opType) + : rank_(rank), opType_(opType) {} + +OpType ProcessGroup::Work::retrieveOpType() { + return opType_; +} + ProcessGroup::Work::~Work() {} bool ProcessGroup::Work::isCompleted() { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 59d40d2427a8..01d835d913cd 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -15,6 +15,32 @@ constexpr auto kNoTimeout = std::chrono::milliseconds(0); namespace c10d { +enum class OpType : std::uint8_t { + BROADCAST = 0, + ALLREDUCE = 1, + ALLREDUCE_COALESCED = 2, + REDUCE = 3, + ALLGATHER = 4, + ALLGATHER_BASE = 5, + ALLGATHER_COALESCED = 6, + GATHER = 7, + SCATTER = 8, + REDUCE_SCATTER = 9, + ALLTOALL_BASE = 10, + ALLTOALL = 11, + SEND = 12, + RECV = 13, + RECVANYSOURCE = 14, + BARRIER = 15, + UNKNOWN = 100, +}; + +// Converts OpType to human readable string. +std::string opTypeToString(OpType opType); + +// Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE) +bool isP2POp(OpType opType); + // ProcessGroup is a base class that captures collective and point to // point communication in a fixed set of processes. // @@ -39,6 +65,10 @@ class ProcessGroup { public: class Work { public: + Work(); + + Work(int rank, OpType opType); + virtual ~Work(); // Checks if request has completed. Non-blocking operation. @@ -93,6 +123,8 @@ class ProcessGroup { // work. Only NCCL backend is currently supported. virtual c10::intrusive_ptr getFuture(); + OpType retrieveOpType(); + protected: // Completes the work object and optionally sets the exception in a // thread-safe manner. Notifies all waiting condition variables as well. @@ -106,6 +138,12 @@ class ProcessGroup { std::condition_variable cv_; bool completed_ = false; std::exception_ptr exception_; + + // Current rank of the node. + const int rank_; + + // Operation type that this work object refers to. 
+ OpType opType_; }; explicit ProcessGroup(int rank, int size); diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 667519a0b8db..0df2d2eccfa6 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -217,8 +217,22 @@ constexpr int64_t kSynchronizeBusyWaitMillis = 10; const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000; thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; -ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) - : devices_(devices), workStartTime_(std::chrono::steady_clock::now()) { +std::ostream& operator<<( + std::ostream& output, + const ProcessGroupNCCL::WorkNCCL& workNCCL) { + return output << "WorkNCCL(" + << "OpType=" << opTypeToString(workNCCL.opType_) + << ", TensorShape=" << (*workNCCL.outputs_)[0].sizes() + << ", Timeout(ms)=" << workNCCL.opTimeout_.count() << ")"; +} + +ProcessGroupNCCL::WorkNCCL::WorkNCCL( + const std::vector& devices, + int rank, + OpType opType) + : Work(rank, opType), + devices_(devices), + workStartTime_(std::chrono::steady_clock::now()) { // Creates the CUDA event wrappers // Note: The actual events are lazily created when first recorded to with // DEFAULT_FLAGS = cudaEventDisableTiming. @@ -228,7 +242,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) } ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) - : std::enable_shared_from_this(w), + : Work(w.rank_, w.opType_), + std::enable_shared_from_this(w), devices_(w.devices_), cudaEvents_(w.cudaEvents_), ncclComms_(w.ncclComms_), @@ -357,9 +372,19 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( ncclComm->ncclCommAbort(); const auto& storeKey = getNcclAbortedCommStoreKey( buildNcclUniqueIdStr(ncclComm->getNcclId())); - store_->set(storeKey, {}); - LOG(INFO) << "Wrote aborted communicator id to store: " << storeKey; + auto rankStr = std::to_string(rank_); + store_->set( + storeKey, + std::vector( + reinterpret_cast(rankStr.data()), + reinterpret_cast(rankStr.data()) + + rankStr.size())); + LOG(INFO) << "[Rank " << rank_ + << "] Wrote aborted communicator id to store: " << storeKey; } + LOG(INFO) << "[Rank " << rank_ + << "] Caught collective operation timeout for work: " + << (*this); throw std::runtime_error("Operation timed out!"); } // Check for errors and throw appropriate exception. 
@@ -412,7 +437,6 @@ void ProcessGroupNCCL::parseNcclAsyncErrorHandling() { auto val = std::stoi(errorHandle); if (val == 1) { asyncErrorHandling_ = true; - LOG(INFO) << "[Rank " << rank_ << "] NCCL Async Error Handling enabled."; } else if (val != 0) { throw std::runtime_error( "Invalid value for environment variable: " + @@ -465,6 +489,12 @@ ProcessGroupNCCL::ProcessGroupNCCL( if (asyncErrorHandling_) { workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this); } + LOG(INFO) << "[Rank " << rank_ + << "] ProcessGroupNCCL initialized with following options:" + << "\nNCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ + << "\nNCCL_BLOCKING_WAIT: " << blockingWait_ + << "\nTIMEOUT(ms): " << opTimeout_.count() + << "\nUSE_HIGH_PRIORITY_STREAM: " << isHighPriorityStream_; } ProcessGroupNCCL::~ProcessGroupNCCL() { @@ -495,12 +525,17 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { void ProcessGroupNCCL::ncclCommWatchdog() { try { + LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread started!"; ncclCommWatchdogInternal(); - LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread terminated normally"; + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated normally"; } catch (std::exception& e) { - LOG(INFO) << "NCCL watchdog thread terminated with exception: " << e.what(); + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated with exception: " + << e.what(); } catch (...) { - LOG(INFO) << "NCCL watchdog thread terminated with unknown exception"; + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated with unknown exception"; } } @@ -521,10 +556,12 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { } if (checkForNCCLErrors(ncclComms)) { - LOG(INFO) << "Received NCCL errors for communicators in the cache"; + LOG(INFO) << "[Rank " << rank_ + << "] Received NCCL errors for communicators in the cache"; if (blockingWait_ || asyncErrorHandling_) { - LOG(INFO) << "Aborting communicators that received errors"; + LOG(INFO) << "[Rank " << rank_ + << "] Aborting communicators that received errors"; // We abort NCCL communicators that have received errors from this // thread, and exceptions are set on the corresponding work objects. // The workCleanupThread will then loop through the unfinished @@ -541,7 +578,8 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { // a communicator the application receives an exception and its // their responsibility to destroy the process group and recreate // it to recover from errors. - abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); + abortedCommIds.emplace( + buildNcclUniqueIdStr(ncclComm->getNcclId())); } } } @@ -560,7 +598,10 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { // Check for Timeouts in the WorkNCCL Operations, and abort all // communicators accordingly. 
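For context, the parsing code above means a value of `1` in `NCCL_ASYNC_ERROR_HANDLING` enables the work-cleanup thread (and `NCCL_BLOCKING_WAIT` is read analogously). A hedged usage sketch; it assumes the usual env:// rendezvous variables are already set before calling `init_process_group`:

```python
import os
from datetime import timedelta
import torch.distributed as dist

# Must be set before the NCCL process group is created.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"   # abort comms and surface errors on timeout
# os.environ["NCCL_BLOCKING_WAIT"] = "1"        # alternative: block and raise in wait()

dist.init_process_group("nccl", timeout=timedelta(seconds=30))
```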
if (work.timedOut()) { - LOG(INFO) << "[" << rank_ << "] caught collective operation timeout"; + LOG(INFO) + << "[Rank " << rank_ + << "] Watchdog caught collective operation timeout for work: " + << work; std::exception_ptr exception_ptr = std::make_exception_ptr( std::runtime_error("NCCL Operation Timed Out")); work.setException(exception_ptr); @@ -583,8 +624,15 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { for (const auto& abortedCommId : abortedCommIds) { abortedComms_.emplace(abortedCommId); const auto& storeKey = getNcclAbortedCommStoreKey(abortedCommId); - store_->set(storeKey, {}); - LOG(INFO) << "Watchdog wrote aborted communicator id to store: " + auto rankStr = std::to_string(rank_); + store_->set( + storeKey, + std::vector( + reinterpret_cast(rankStr.data()), + reinterpret_cast(rankStr.data()) + + rankStr.size())); + LOG(INFO) << "[Rank " << rank_ + << "] Watchdog wrote aborted communicator id to store: " << storeKey; } @@ -598,7 +646,11 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { store_->wait( {storeKey}, std::chrono::milliseconds(kWaitForAbortCommStoreKey)); - LOG(INFO) << "Found key in store: " << storeKey + auto val = store_->get(storeKey); + std::string rank(reinterpret_cast(val.data()), val.size()); + LOG(INFO) << "[Rank " << rank_ + << "] Found key in store: " << storeKey + << ", from rank: " << rank << ", aborting appropriate communicators"; // Now abort the appropriate communicators. @@ -609,7 +661,9 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { ncclComm->ncclCommAbort(); } abortedComms_.emplace(commId); - LOG(INFO) << "Aborted communicators for key in store: " << storeKey; + LOG(INFO) << "[Rank " << rank_ + << "] Aborted communicators for key in store: " + << storeKey; } catch (std::exception& e) { VLOG(1) << "Did not find key in store: " << storeKey << ", error: " << e.what(); @@ -685,15 +739,32 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( return nullptr; } -void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) { +void ProcessGroupNCCL::broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + OpType opType, + const std::string& p2pKey, + int p2pRank) { + // For collective operations: // For every NCCL communicator that we create we need to broadcast // a unique ID from rank 0 to all other ranks. This broadcast is // done by rank 0 setting a key in the store and all other ranks // retrieving the contents of that key. A single process group // may create multiple NCCL communicators, so we use a sequence // number to differentiate between them. - std::string storeKey = std::to_string(ncclCommCounter_++); - if (rank_ == 0) { + // For point-to-point operations: + // The sequence number will only be increased on 2 out of all the + // processes in a Process Group. So all following collective + // operations will see different sequence numbers which will cause + // runtime errors. To avoid that, use the src:target pair instead + // of sequence number for p2p communications. 
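The comment above explains why p2p operations cannot reuse the collective sequence counter as the store key for broadcasting the NCCL unique id. A tiny illustrative sketch of that rule (plain Python; the real p2p key comes from `getKeySendRecv`, whose exact string format is an assumption here):

```python
def nccl_id_store_key(is_p2p, counter, p2p_key):
    # Collectives: a per-group sequence number, bumped identically on every rank.
    # P2P: a key derived from the (src, dst) pair, so ranks that never take part
    # in this send/recv do not fall out of sync with the counter.
    if is_p2p:
        return p2p_key
    key = str(counter[0])
    counter[0] += 1
    return key

counter = [0]
print(nccl_id_store_key(False, counter, None))   # "0"
print(nccl_id_store_key(True, counter, "3:7"))   # "3:7", counter untouched
print(nccl_id_store_key(False, counter, None))   # "1"
```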
+ + std::string storeKey; + if (!isP2POp(opType)) { + storeKey = std::to_string(ncclCommCounter_++); + } else { + storeKey = p2pKey; + } + if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) { auto vec = std::vector( reinterpret_cast(ncclID), reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); @@ -708,8 +779,9 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) { std::vector>& ProcessGroupNCCL::getNCCLComm( const std::string& devicesKey, const std::vector& devices, - NCCLCommType commType, - int p2pRank) { + OpType opType, + int p2pRank, + bool isSendRecvSelf) { // Sanity check if (devicesKey.empty()) { throw std::runtime_error( @@ -737,12 +809,15 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( ncclUniqueId ncclID; // For point-to-point communication, lower rank of the two will get unique id. - if (rank_ == 0 || (commType != NCCLCommType::COLL && p2pRank == 0)) { + if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) { C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID)); } - // Broadcast so that each process can have a unique NCCL ID - broadcastUniqueNCCLID(&ncclID); + // For point-to-point communication on the same process, don't need broadcast. + if (!isSendRecvSelf) { + // Broadcast so that each process can have a unique NCCL ID + broadcastUniqueNCCLID(&ncclID, opType, devicesKey, p2pRank); + } at::cuda::OptionalCUDAGuard gpuGuard; @@ -775,12 +850,16 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // GPU world size and GPU rank int numRanks, rank; - if (commType == NCCLCommType::COLL) { + if (!isP2POp(opType)) { numRanks = getSize() * devices.size(); rank = getRank() * devices.size() + i; + } else if(isSendRecvSelf) { + // Same process send and recv. + numRanks = 1; + rank = 0; } else { - // For point-to-point operation, there are only 2 processes involved so - // the GPU rank is either 0 or 1. + // For point-to-point operation, there are only 2 processes involved so + // the GPU rank is either 0 or 1. 
numRanks = 2; rank = p2pRank; } @@ -798,7 +877,8 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( std::lock_guard lock(mutex_); if (futureNCCLCallbackStreams_[deviceIndex] == nullptr) { futureNCCLCallbackStreams_[deviceIndex] = - std::make_shared(at::cuda::getStreamFromPool(isHighPriorityStream_)); + std::make_shared( + at::cuda::getStreamFromPool(isHighPriorityStream_)); } } @@ -929,8 +1009,10 @@ std::vector flatten_for_scatter_gather( } // namespace std::shared_ptr ProcessGroupNCCL::initWork( - std::vector devices) { - return std::make_shared(devices); + std::vector devices, + int rank, + OpType opType) { + return std::make_shared(devices, rank, opType); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -964,7 +1046,8 @@ void ProcessGroupNCCL::workEnqueue( } } ProcessGroupNCCL::Options::Options() - : opTimeout(kProcessGroupNCCLOpTimeoutMillis), isHighPriorityStream(false) {} + : opTimeout(kProcessGroupNCCLOpTimeoutMillis), + isHighPriorityStream(false) {} template std::shared_ptr ProcessGroupNCCL::collective( @@ -972,16 +1055,17 @@ std::shared_ptr ProcessGroupNCCL::collective( std::vector& outputs, Fn fn, PreProcess pre, - PostProcess post) { + PostProcess post, + OpType opType) { const auto devices = getDeviceList(inputs); const auto key = getKeyFromDevices(devices); - auto& ncclComms = getNCCLComm(key, devices); + auto& ncclComms = getNCCLComm(key, devices, opType); // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices); + auto work = initWork(devices, rank_, opType); // Store references to outputs and futureNCCLCallbackStream to be used by // WorkNCCL::getFuture. @@ -1025,11 +1109,13 @@ std::shared_ptr ProcessGroupNCCL::collective( at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; (*work->cudaEvents_)[i].record(ncclStream); work->ncclComms_[i] = ncclComms[i]; - work->blockingWait_ = blockingWait_; - work->opTimeout_ = opTimeout_; - work->store_ = store_; } + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->opTimeout_ = opTimeout_; + work->store_ = store_; + if (asyncErrorHandling_) { workEnqueue(work); } @@ -1042,21 +1128,22 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensors, Fn fn, int peer, - NCCLCommType commType, + OpType opType, PreProcess pre, PostProcess post) { const auto devices = getDeviceList(tensors); const auto key = getKeySendRecv(rank_, peer); - int p2pRank = rank_ < peer ? 0 : 1; - auto& ncclComms = getNCCLComm(key, devices, commType, p2pRank); + int p2pRank = rank_ <= peer ? 0 : 1; + auto isSendRecvSelf = rank_ == peer; + auto& ncclComms = getNCCLComm(key, devices, opType, p2pRank, isSendRecvSelf); // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices); + auto work = initWork(devices, rank_, opType); - if (commType == NCCLCommType::RECV) { + if (opType == OpType::RECV) { // Store references to outputs and futureNCCLCallbackStream to be used by // WorkNCCL::getFuture. work->outputs_ = std::make_shared>(tensors); @@ -1087,7 +1174,7 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; // For point-to-point communication, NCCL ranks can only // be 0 or 1. 
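With the p2p communicator setup above (two ranks, or a single rank sending to itself), point-to-point transfers can be driven through the usual `torch.distributed` API. A hedged two-rank sketch, assuming the process group was initialized with the `nccl` backend and each rank owns one GPU:

```python
import torch
import torch.distributed as dist

rank = dist.get_rank()
t = torch.ones(4, device=f"cuda:{rank}")

if rank == 0:
    dist.send(t * 2, dst=1)    # OpType::SEND
elif rank == 1:
    dist.recv(t, src=0)        # OpType::RECV
    print(t)                   # tensor([2., 2., 2., 2.], device='cuda:1')
```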
- int p2pTargetRank = 1 - p2pRank; + int p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; C10D_NCCL_CHECK( fn(tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank)); } @@ -1112,13 +1199,15 @@ template std::shared_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, - Fn fn) { + Fn fn, + OpType opType) { return collective( inputs, outputs, fn, [](std::vector&) {}, - [](std::vector&) {}); + [](std::vector&) {}, + opType); } template @@ -1126,12 +1215,12 @@ std::shared_ptr ProcessGroupNCCL::pointToPoint( std::vector& tensor, Fn fn, int peer, - NCCLCommType type) { + OpType opType) { return pointToPoint( tensor, fn, peer, - type, + opType, [](std::vector&) {}, [](std::vector&) {}); } @@ -1156,7 +1245,8 @@ std::shared_ptr ProcessGroupNCCL::allreduce( getNcclReduceOp(opts.reduceOp, input), comm, stream.stream()); - }); + }, + OpType::ALLREDUCE); } std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( @@ -1186,7 +1276,8 @@ std::shared_ptr ProcessGroupNCCL::broadcast( root, comm, stream.stream()); - }); + }, + OpType::BROADCAST); } std::shared_ptr ProcessGroupNCCL::reduce( @@ -1211,7 +1302,8 @@ std::shared_ptr ProcessGroupNCCL::reduce( root, comm, stream.stream()); - }); + }, + OpType::REDUCE); } std::shared_ptr ProcessGroupNCCL::allgather( @@ -1254,7 +1346,8 @@ std::shared_ptr ProcessGroupNCCL::allgather( outputTensors[i][j].copy_(outputFlattened[i][j], true); } } - }); + }, + OpType::ALLGATHER); } std::shared_ptr ProcessGroupNCCL::allgather_coalesced( @@ -1306,7 +1399,8 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( } } }, - [&](std::vector& ncclStreams) {}); + [&](std::vector& ncclStreams) {}, + OpType::REDUCE_SCATTER); } std::shared_ptr ProcessGroupNCCL::barrier( @@ -1375,7 +1469,8 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( comm, stream); return ncclSuccess; - }); + }, + OpType::ALLTOALL_BASE); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); @@ -1407,7 +1502,8 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( getNcclDataType(input.scalar_type()), comm, stream.stream()); - }); + }, + OpType::ALLTOALL_BASE); } } @@ -1417,21 +1513,16 @@ std::shared_ptr ProcessGroupNCCL::send( int /* unused */) { check_gpu_tensors(tensors); auto ret = pointToPoint( - tensors, + tensors, [&](at::Tensor& input, ncclComm_t comm, at::cuda::CUDAStream& stream, int dst) { - return ncclSend( - input.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - dst, - comm, - stream.stream()); + torch::cuda::nccl::send(input, comm, stream, dst); + return ncclSuccess; }, dstRank, - NCCLCommType::SEND); + OpType::SEND); return ret; } @@ -1440,22 +1531,17 @@ std::shared_ptr ProcessGroupNCCL::recv( int srcRank, int /* unused */) { check_gpu_tensors(tensors); - auto ret= pointToPoint( + auto ret = pointToPoint( tensors, [&](at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream, int src) { - return ncclRecv( - output.data_ptr(), - output.numel(), - getNcclDataType(output.scalar_type()), - src, - comm, - stream.stream()); + torch::cuda::nccl::recv(output, comm, stream, src); + return ncclSuccess; }, srcRank, - NCCLCommType::RECV); + OpType::RECV); return ret; } #else diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b8b3d5aabd35..3b5f6b5c02df 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -23,13 +24,6 @@ constexpr 
const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT"; // Handling with NCCL. constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; -// NCCL Commmunication type -enum class NCCLCommType : std::uint8_t { - SEND = 0, - RECV, - COLL, -}; - // ProcessGroupNCCL implements NCCL bindings for c10d. // // All functions of the class are expected to be called in the same order @@ -71,7 +65,7 @@ class ProcessGroupNCCL : public ProcessGroup { public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices - WorkNCCL(const std::vector& devices); + WorkNCCL(const std::vector& devices, int rank, OpType opType); // Copy constructor doing partial copy without outputs_. Cleanup thread // monitors and removes finished works. However it will deadlock when // destructs outputs_ tensors who are view tensors in autograd graph. @@ -147,6 +141,10 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) const; + friend std::ostream& operator<<( + std::ostream& output, + const WorkNCCL& workNCCL); + private: // Helper function for synchronize void synchronizeInternal(std::chrono::milliseconds timeout); @@ -166,6 +164,7 @@ class ProcessGroupNCCL : public ProcessGroup { // Store a reference to NCCL collective's outputs to be used by getFuture. std::shared_ptr> outputs_; + // Store streams that run FutureNCCL then callbacks. std::vector> futureNCCLCallbackStreams_; @@ -460,22 +459,29 @@ class ProcessGroupNCCL : public ProcessGroup { protected: // Helper that broadcasts nccl unique ID to all ranks through the store - void broadcastUniqueNCCLID(ncclUniqueId* ncclID); + void broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + OpType opType, + const std::string& devicesKey, + int p2pRank); // Helper that either looks up the cached NCCL communicators or creates // a new set of NCCL communicators as a cache entry std::vector>& getNCCLComm( const std::string& devicesKey, const std::vector& devices, - NCCLCommType commType = NCCLCommType::COLL, - int p2pRank = 0); + OpType opType, + int p2pRank = 0, + bool isSendRecvSelf = false); // Wrapper method which can be overridden for tests. virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms); virtual std::shared_ptr initWork( - std::vector devices); + std::vector devices, + int rank, + OpType opType); private: // Helper that encapsulates work shared across all collective communication @@ -488,14 +494,16 @@ class ProcessGroupNCCL : public ProcessGroup { std::shared_ptr collective( std::vector& input, std::vector& output, - Fn fn); + Fn fn, + OpType opType); template std::shared_ptr collective( std::vector& input, std::vector& output, Fn fn, PreProcess pre, - PostProcess post); + PostProcess post, + OpType opType); // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective @@ -505,13 +513,13 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& tensor, Fn fn, int peer, - NCCLCommType commType); + OpType opType); template std::shared_ptr pointToPoint( std::vector& tensor, Fn fn, int peer, - NCCLCommType commType, + OpType opType, PreProcess pre, PostProcess post); @@ -537,8 +545,8 @@ class ProcessGroupNCCL : public ProcessGroup { // accordingly. void parseNcclBlockingWait(); - // Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets asyncErrorHandling_ - // accordingly. 
+ // Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets + // asyncErrorHandling_ accordingly. void parseNcclAsyncErrorHandling(); void workCleanupLoop(); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 93f633938e18..0df197d17cbb 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -16,8 +16,10 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLSimulateErrors( const std::vector& devices, - bool simulate_error) - : WorkNCCL(devices), simulate_error_(simulate_error) {} + bool simulate_error, + int rank, + c10d::OpType opType) + : WorkNCCL(devices, rank, opType), simulate_error_(simulate_error) {} std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) @@ -55,8 +57,11 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { } std::shared_ptr initWork( - std::vector devices) override { - return std::make_shared(devices, simulate_error_); + std::vector devices, + int rank, + c10d::OpType opType) override { + return std::make_shared( + devices, simulate_error_, rank, opType); } size_t getNCCLCommCacheSize() { @@ -79,8 +84,11 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLTimedoutErrors( const std::vector& devices, - bool set_timedout_error) - : WorkNCCL(devices), set_timedout_error_(set_timedout_error) {} + bool set_timedout_error, + int rank, + c10d::OpType opType) + : WorkNCCL(devices, rank, opType), + set_timedout_error_(set_timedout_error) {} private: bool isCompleted() override { @@ -105,9 +113,11 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { set_timedout_error_(false) {} std::shared_ptr initWork( - std::vector devices) override { + std::vector devices, + int rank, + c10d::OpType opType) override { return std::make_shared( - devices, set_timedout_error_); + devices, set_timedout_error_, rank, opType); } void set_timedout_error() { diff --git a/torch/nn/functional.py b/torch/nn/functional.py index d1575c14323a..9a48c54eb4fb 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2816,6 +2816,8 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N """) channel_shuffle = _add_docstr(torch.channel_shuffle, r""" +channel_shuffle(input, groups) -> Tensor + Divide the channels in a tensor of shape :math:`(*, C , H, W)` into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, while keeping the original tensor shape. 
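Since the docstring above now carries an explicit `channel_shuffle(input, groups) -> Tensor` signature, a small usage example of the functional form and the matching `nn.ChannelShuffle` module imported in the `__init__.py` change below (the module file itself is outside this excerpt):

```python
import torch
import torch.nn as nn

x = torch.arange(16.).reshape(1, 4, 2, 2)   # 4 channels, split into 2 groups of 2
y = torch.channel_shuffle(x, groups=2)      # functional form
shuffle = nn.ChannelShuffle(2)              # module form added by this change
assert torch.equal(y, shuffle(x))
```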
diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index 06a565700550..a9128a430fea 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -31,6 +31,7 @@ from .transformer import TransformerEncoder, TransformerDecoder, \ TransformerEncoderLayer, TransformerDecoderLayer, Transformer from .flatten import Flatten, Unflatten +from .channelshuffle import ChannelShuffle __all__ = [ 'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', @@ -54,5 +55,5 @@ 'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold', 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', - 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss' + 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle' ] diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py index 8c660fcb73a0..2344543f23f8 100644 --- a/torch/nn/quantized/modules/embedding_ops.py +++ b/torch/nn/quantized/modules/embedding_ops.py @@ -195,7 +195,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype) self.mode = mode - self.sparse = sparse + self.pruned_weights = False self.include_last_offset = include_last_offset self.dtype = dtype @@ -203,11 +203,11 @@ def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_ compressed_indices_mapping: Optional[Tensor] = None) -> Tensor: if self.dtype == torch.quint4x2: return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0, - self.sparse, per_sample_weights, compressed_indices_mapping, + self.pruned_weights, per_sample_weights, compressed_indices_mapping, self.include_last_offset) else: return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0, - self.sparse, per_sample_weights, compressed_indices_mapping, + self.pruned_weights, per_sample_weights, compressed_indices_mapping, self.include_last_offset) def _get_name(self): diff --git a/torch/nn/utils/fusion.py b/torch/nn/utils/fusion.py index e0f512a24f3e..c4f164ee3b40 100644 --- a/torch/nn/utils/fusion.py +++ b/torch/nn/utils/fusion.py @@ -15,7 +15,7 @@ def fuse_conv_bn_eval(conv, bn): def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): if conv_b is None: - conv_b = bn_rm.new_zeros(bn_rm.shape) + conv_b = torch.zeros_like(bn_rm) if bn_w is None: bn_w = torch.ones_like(bn_rm) if bn_b is None: diff --git a/torch/overrides.py b/torch/overrides.py index dc434e9c1f58..8b363b9f2bf0 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -469,6 +469,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: return_indices=False, ceil_mode=False: -1), torch.mean: lambda input, dim=None: -1, torch.median: lambda input, dim=None: -1, + torch.nanmedian: lambda input, dim=None: -1, torch.meshgrid: lambda *tensors, **kwargs: -1, torch.min: lambda input, out=None: -1, torch.minimum: lambda input, other, out=None: -1, diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 74dee6ea3cf3..5839a9f85160 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -2,7 +2,6 @@ from torch.fx import ( GraphModule, Proxy, - symbolic_trace, map_arg ) @@ -413,9 +412,8 
@@ def insert_observer(node, observer, device): if isinstance(obj, StandaloneModuleQuantizeHandler): # observe standalone module standalone_module = self.modules[node.target] - traced_standalone_module = symbolic_trace(standalone_module) prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx - observed_standalone_module = prepare(traced_standalone_module, {'': qconfig}) + observed_standalone_module = prepare(standalone_module, {'': qconfig}) observed_standalone_module.qconfig = qconfig standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs observed_standalone_module = mark_observed_standalone_module(observed_standalone_module) diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 7fa3fcbee82f..47cbc4deb019 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -1,17 +1,20 @@ +import torch +from torch.fx import GraphModule # type: ignore +from torch.fx import symbolic_trace # type: ignore +from torch.fx.symbolic_trace import Tracer # type: ignore from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 -from torch.fx import GraphModule # type: ignore from .fx.utils import graph_pretty_str # noqa: F401 def _check_is_graph_module(model): if not isinstance(model, GraphModule): raise ValueError( 'input model must be a GraphModule, ' + - 'please run torch.fx.symbolic_trace on your model before using ' + - 'quantize_fx. Got type:' + str(type(model))) + 'Got type:' + str(type(model)) + ' Please make ' + + 'sure to follow the tutorials.') -def fuse_fx(graph_module, inplace=False): - r""" Fuse modules in preparation for quantization +def _fuse_fx(graph_module, inplace=False): + r""" Internal helper function to fuse modules in preparation for quantization Args: graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace) @@ -20,10 +23,21 @@ def fuse_fx(graph_module, inplace=False): fuser = Fuser() return fuser.fuse(graph_module, inplace) -def _prepare_fx(graph_module, qconfig_dict, inplace, is_standalone_module=False): +class CustomTracer(Tracer): + def __init__(self, standalone_modules): + super().__init__() + self.standalone_modules = standalone_modules + + def is_leaf_module(self, m, module_qualified_name): + return (m.__module__.startswith('torch.nn') and + not isinstance(m, torch.nn.Sequential)) or \ + module_qualified_name in self.standalone_modules + + +def _prepare_fx(model, qconfig_dict, inplace, is_standalone_module=False): r""" Internal helper function for prepare_fx Args: - `graph_modul`e, `qconfig_dict`, `inplace`: see docs for :func:`~torch.quantization.prepare_fx` + `model`, `qconfig_dict`, `inplace`: see docs for :func:`~torch.quantization.prepare_fx` `is_standalone_module`: a boolean flag indicates whether we are quantizing a standalone module or not, a standalone module is a submodule of the parent module that is not inlined in the @@ -31,12 +45,19 @@ def _prepare_fx(graph_module, qconfig_dict, inplace, is_standalone_module=False) the way we quantize standalone module is described in: :func:`~torch.quantization._prepare_standalone_module_fx` """ - _check_is_graph_module(graph_module) - graph_module = fuse_fx(graph_module, inplace) + # symbolically trace the model + if is_standalone_module: + # standlone module is traced before quantizing standalone modules + graph_module = symbolic_trace(model) + else: + standalone_modules = qconfig_dict.get('standalone_module_name', []) + # skipping tracing standalone modules when 
tracing top level module + graph_module = GraphModule(model, CustomTracer(standalone_modules).trace(model)) + graph_module = _fuse_fx(graph_module, inplace) quantizer = Quantizer() return quantizer.prepare(graph_module, qconfig_dict, inplace=True, is_standalone_module=is_standalone_module) -def _prepare_standalone_module_fx(graph_module, qconfig_dict, inplace=False): +def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False): r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module. standalone_module means it a submodule that is not inlined in parent module, @@ -51,14 +72,32 @@ def _prepare_standalone_module_fx(graph_module, qconfig_dict, inplace=False): custom module is observed or not """ - return _prepare_fx(graph_module, qconfig_dict, inplace, is_standalone_module=True) + return _prepare_fx(model, qconfig_dict, inplace, is_standalone_module=True) + + +def fuse_fx(model, inplace=False): + r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. + Fusion rules are defined in torch.quantization.fx.fusion_pattern.py + Args: + `model`: a torch.nn.Module model + `inplace`: flag for whether we fuse modules inplace or out of place + + Example: + ```python + from torch.quantization import fuse_fx + m = Model().eval() + m = fuse_fx(m) + ``` + """ + assert not model.training, 'fuse_fx only works on models in eval mode' + graph_module = torch.fx.symbolic_trace(model) + return _fuse_fx(graph_module, inplace) -def prepare_fx(graph_module, qconfig_dict, inplace=False): +def prepare_fx(model, qconfig_dict, inplace=False): r""" Prepare a model for post training static quantization Args: - `graph_module`: model from symbolic_tracing (torch.fx.symbolic_trace), must be - an eval model + `model`: torch.nn.Module model, must be in eval mode `qconfig_dict`: qconfig_dict is a dictionary with the following configurations: qconfig_dict = { # optional, global config @@ -89,9 +128,6 @@ def prepare_fx(graph_module, qconfig_dict, inplace=False): # optional: specify the path for standalone modules # These modules are symbolically traced and quantized as one unit - # User should also skip symbolic tracing through these modules - # so that the call to the submodule appears as one call_module - # node in the forward graph of the GraphModule "standalone_module_name": [ "submodule.standalone" ] @@ -122,21 +158,16 @@ def calibrate(model, data_loader): prepared_model = prepare_fx(graph_module, qconfig_dict) # Run calibration calibrate(prepared_model, sample_inference_data) + ``` """ - assert not graph_module.training, 'prepare_fx only works for models in' + \ - 'eval mode' - return _prepare_fx(graph_module, qconfig_dict, inplace) - -def prepare_static_fx(graph_module, qconfig_dict, inplace=False): - assert not graph_module.training, 'prepare_static_fx only works for models in ' + \ + assert not model.training, 'prepare_fx only works for models in' + \ 'eval mode' - return prepare_fx(graph_module, qconfig_dict, inplace) + return _prepare_fx(model, qconfig_dict, inplace) -def prepare_qat_fx(graph_module, qconfig_dict, inplace=False): +def prepare_qat_fx(model, qconfig_dict, inplace=False): r""" Prepare a model for quantization aware training Args: - `graph_module`: model from symbolic_tracing (torch.fx.symbolic_trace), must be - a train model + `model`: torch.nn.Module model, must be in train mode `qconfig_dict`: see :func:`~torch.quantization.prepare_fx` `inplace`: flag for carry out model transformations in-place, the original module is 
mutated @@ -151,22 +182,22 @@ def prepare_qat_fx(graph_module, qconfig_dict, inplace=False): from torch.quantization import get_default_qat_qconfig from torch.quantization import prepare_fx - float_model.train() - graph_module = torch.fx.symbolic_trace(float_model) qconfig = get_default_qat_qconfig('fbgemm') def train_loop(model, train_data): model.train() for image, target in data_loader: ... + float_model.train() qconfig_dict = {"": qconfig} - prepared_model = prepare_fx(graph_module, qconfig_dict) + prepared_model = prepare_fx(float_model, qconfig_dict) # Run calibration train_loop(prepared_model, train_loop) + ``` """ - assert graph_module.training, 'prepare_qat_fx only works for models in ' + \ + assert model.training, 'prepare_qat_fx only works for models in ' + \ 'train mode' - return _prepare_fx(graph_module, qconfig_dict, inplace) + return _prepare_fx(model, qconfig_dict, inplace) def _convert_fx(graph_module, inplace, debug, is_standalone_module=False): """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx` diff --git a/torch/testing/_internal/codegen/random_topo_test.py b/torch/testing/_internal/codegen/random_topo_test.py index 2eed0bad4e43..cf27fadff314 100644 --- a/torch/testing/_internal/codegen/random_topo_test.py +++ b/torch/testing/_internal/codegen/random_topo_test.py @@ -285,7 +285,7 @@ def runDefaultTestWithSeed(seed): jit_o = traced_model(seed_tensor, *tensor_list) validate_o = zip(o, jit_o) for oo, jit_oo in validate_o: - if not oo.allclose(jit_oo, equal_nan=True): + if not oo.allclose(jit_oo, atol=1e-5, equal_nan=True): return False return True diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 2a99ae643931..af08a7f9d953 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -80,7 +80,7 @@ def __init__(self, # NOTE: if the op is unspecified it is assumed to be under the torch namespace if op is None: - assert hasattr(torch, self.name) + assert hasattr(torch, self.name), f"Can't find torch.{self.name}" self.op = op if op else getattr(torch, self.name) self.method_variant = getattr(torch.Tensor, name) if hasattr(torch.Tensor, name) else None inplace_name = name + "_" @@ -818,10 +818,14 @@ def method_tests(): ('median', (S, S, S), (1,), 'dim', (), [0]), ('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('median', (), NO_ARGS, 'scalar'), - # TODO: https://github.com/pytorch/pytorch/issues/30818 - ('median', (), (0,), 'scalar_dim', (), [0], [expectedFailureCUDA]), - ('median', (), (0, True,), 'scalar_keepdim_dim', (), [0], [expectedFailureCUDA]), - # END TODO + ('median', (), (0,), 'scalar_dim', (), [0]), + ('median', (), (0, True,), 'scalar_keepdim_dim', (), [0]), + ('nanmedian', (S, S, S), NO_ARGS), + ('nanmedian', (S, S, S), (1,), 'dim', (), [0]), + ('nanmedian', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), + ('nanmedian', (), NO_ARGS, 'scalar'), + ('nanmedian', (), (0,), 'scalar_dim', (), [0]), + ('nanmedian', (), (0, True,), 'scalar_keepdim_dim', (), [0]), ('mode', (S, S, S), NO_ARGS), ('mode', (S, S, S), (1,), 'dim', (), [0]), ('mode', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 235e88f3c823..baec963ce10a 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ 
b/torch/testing/_internal/distributed/distributed_test.py @@ -617,6 +617,63 @@ def test_batch_isend_irecv_nccl(self): self._barrier() + @skip_if_no_gpu + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_self_nccl(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + p2p_op_list = [] + + if rank == 0: + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id) + recv_op = dist._P2POp(dist.irecv, recv_tensor, 0) + p2p_op_list.append(recv_op) + send_op = dist._P2POp(dist.isend, send_tensor, 0) + p2p_op_list.append(send_op) + + reqs = dist._batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + @skip_if_no_gpu + @skip_if_small_worldsize + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_no_rank_zero_nccl(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + p2p_op_list = [] + + if rank == 1: + peer = 2 + elif rank == 2: + peer = 1 + + if rank in [1, 2]: + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id) + recv_op = dist._P2POp(dist.irecv, recv_tensor, peer) + p2p_op_list.append(recv_op) + send_op = dist._P2POp(dist.isend, send_tensor, peer) + p2p_op_list.append(send_op) + + reqs = dist._batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + + self._barrier() + # GLOO Batch SEND RECV CPU @unittest.skipIf(BACKEND != "gloo", "GLOO Batch Send Recv CPU") def test_batch_isend_irecv_gloo(self): @@ -3605,6 +3662,37 @@ def forward(self, x): # isolate failure hangs. torch.cuda.synchronize(device=self.rank) + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_unused_params_rebuild_buckets_exception(self): + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10, bias=False) + self.net2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + return self.net1(x) + + ddp = torch.nn.parallel.DistributedDataParallel( + ToyModel().cuda(self.rank), device_ids=[self.rank] + ) + for i in range(2): + inp = torch.rand(1, 10) + if i > 0: + # On 2nd iteration, this will fail during rebuild_buckets, + # but we should report an error regarding unused parameters + # since that is the underlying root cause. 
+ with self.assertRaisesRegex( + RuntimeError, + "Expected to have finished reduction in the prior iteration", + ): + ddp(inp).sum().backward() + else: + ddp(inp).sum().backward() + @require_backend({"gloo", "nccl"}) @require_backends_available({"gloo", "nccl"}) @skip_if_lt_x_gpu(2) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 896bd1cc78db..143af4be06fa 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -4535,7 +4535,7 @@ def _add_to_gpu(x, y): def _test_device_maps_missing_config(self, mode): dst = worker_name((self.rank + 1) % self.world_size) errMsg = ( - "TensorPipeAgent only supports CPU tensors by default.*" + "TensorPipe RPC backend only supports CPU tensors by default.*" "`set_device_map` on `TensorPipeRpcBackendOptions`" ) @@ -4580,6 +4580,28 @@ def _test_device_maps_missing_config_response(self, mode): def test_device_maps_missing_config(self): self._test_device_maps_missing_config(RPCExecMode.SYNC) + @skip_if_lt_x_gpu(1) + def test_device_maps_missing_config_not_timeout(self): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options + ) + + timeout = rpc.get_rpc_timeout() + + tik = time.time() + self._test_device_maps_missing_config(RPCExecMode.SYNC) + rpc.shutdown() + tok = time.time() + + self.assertTrue(tok - tik < timeout) + @skip_if_lt_x_gpu(1) @dist_init def test_device_maps_missing_config_loop(self): diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 732260573ecc..e1dd6ca9fb93 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -24,7 +24,7 @@ from functools import reduce from itertools import chain from torch._six import StringIO -from typing import Any, Dict +from collections import defaultdict import inspect import io @@ -34,6 +34,7 @@ import sys import tempfile import textwrap +from typing import Any, Dict, List RUN_CUDA = torch.cuda.is_available() RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1 @@ -89,6 +90,7 @@ def __exit__(self, type, value, traceback): return True +FUSION_GROUP = "prim::TensorExprGroup" class JitTestCase(TestCase): _do_cuda_memory_leak_check = True @@ -132,6 +134,35 @@ def tearDown(self): self.clearHooks() clear_class_registry() + def assertAllFused(self, graph, except_for=()): + + # note this helper collects nodes on 'fast path' only + # i.e. 
the true blocks of specialized checks + def get_nodes_and_parents_recursively(block, kind, acc): + for node in block.nodes(): + if node.kind() == kind: + acc[block].append(node) + elif node.kind() == 'prim::DifferentiableGraph': + get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc) + elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or + node.inputs().__next__().node().kind() == 'prim::TypeCheck'): + get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc) + else: + for inner_block in node.blocks(): + get_nodes_and_parents_recursively(inner_block, kind, acc) + + allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate', + 'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck'} | set(except_for) + + fusion_groups : Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list) + get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups) + self.assertTrue(len(fusion_groups) == 1, 'got {}'.format(graph)) + (graph, fusion_nodes) = list(fusion_groups.items())[0] + # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes` + self.assertTrue(len(fusion_nodes) == 1, 'got {}'.format(graph)) + self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), + 'got {}'.format(graph)) + def _isHookExceptionOk(self, e): se = str(e) allowed = ("Could not export Python function", diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index feecc39acd87..37ee47e2e6ec 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -348,6 +348,18 @@ def finalize_options(self) -> None: def build_extensions(self) -> None: self._check_abi() for extension in self.extensions: + # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when + # extra_compile_args is a dict. Otherwise, default torch flags do + # not get passed. Necessary when only one of 'cxx' and 'nvcc' is + # passed to extra_compile_args in CUDAExtension, i.e. + # CUDAExtension(..., extra_compile_args={'cxx': [...]}) + # or + # CUDAExtension(..., extra_compile_args={'nvcc': [...]}) + if isinstance(extension.extra_compile_args, dict): + for ext in ['cxx', 'nvcc']: + if ext not in extension.extra_compile_args: + extension.extra_compile_args[ext] = [] + self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H') self._define_torch_extension_name(extension) self._add_gnu_cpp_abi_flag(extension)
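To illustrate the `build_extensions()` change at the end of this diff, here is a hedged `setup.py` sketch of the case it targets: `extra_compile_args` passed as a dict with only the `'cxx'` key. The extension and source names are placeholders, not files from this PR.

```python
# Minimal setup.py sketch. With the new default, a missing 'nvcc' entry is
# filled with an empty list, so the default torch flags (e.g.
# -DTORCH_API_INCLUDE_EXTENSION_H and the extension-name define) are still
# appended for nvcc even though only 'cxx' flags were supplied.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='my_ext',  # placeholder name
    ext_modules=[
        CUDAExtension(
            name='my_ext',
            sources=['my_ext.cpp', 'my_ext_kernel.cu'],  # placeholder sources
            # Only 'cxx' is provided; build_extensions() now inserts
            # extra_compile_args['nvcc'] = [] before adding the torch defaults.
            extra_compile_args={'cxx': ['-O3']},
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)
```

Per the comment added in the hunk above, without this default the torch-provided flags would only be appended to the keys already present in the dict, so the side not listed by the user would silently miss them.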