diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py
index e4e6767a444ead9..81479d9812042d6 100644
--- a/.circleci/cimodel/data/pytorch_build_data.py
+++ b/.circleci/cimodel/data/pytorch_build_data.py
@@ -5,7 +5,9 @@
("xenial", [
("rocm", [
("3.5.1", [
- X("3.6"),
+ ("3.6", [
+ ('build_only', [XImportant(True)]),
+ ]),
]),
]),
("gcc", [
diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py
index 3201d4581c9a96b..0df2c8ed920b6b3 100644
--- a/.circleci/cimodel/data/pytorch_build_definitions.py
+++ b/.circleci/cimodel/data/pytorch_build_definitions.py
@@ -307,7 +307,7 @@ def instantiate_configs():
parallel_backend = fc.find_prop("parallel_backend") or None
build_only = fc.find_prop("build_only") or False
is_coverage = fc.find_prop("is_coverage") or False
- if build_only and restrict_phases is None:
+ if build_only:
restrict_phases = ["build"]
if is_coverage and restrict_phases is None:
restrict_phases = ["build", "coverage_test"]
diff --git a/.circleci/cimodel/data/simple/mobile_definitions.py b/.circleci/cimodel/data/simple/mobile_definitions.py
index af48010937d84ee..9c326b779eba230 100644
--- a/.circleci/cimodel/data/simple/mobile_definitions.py
+++ b/.circleci/cimodel/data/simple/mobile_definitions.py
@@ -57,11 +57,6 @@ def gen_tree(self):
[DOCKER_REQUIREMENT_ASAN],
["build"]
),
- MobileJob(
- DOCKER_IMAGE_ASAN,
- [DOCKER_REQUIREMENT_ASAN],
- ["custom", "build", "static"]
- ),
# Use LLVM-DEV toolchain in android-ndk-r19c docker image
MobileJob(
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 79176e5e19c9421..b9be825f5d06797 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -6120,54 +6120,9 @@ workflows:
name: pytorch_linux_xenial_rocm3_5_1_py3_6_build
requires:
- "docker-pytorch-linux-xenial-rocm3.5.1-py3.6"
- filters:
- branches:
- only:
- - master
- - /ci-all\/.*/
- - /release\/.*/
build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6"
resource_class: xlarge
- - pytorch_linux_test:
- name: pytorch_linux_xenial_rocm3_5_1_py3_6_test1
- requires:
- - pytorch_linux_xenial_rocm3_5_1_py3_6_build
- filters:
- branches:
- only:
- - master
- - /ci-all\/.*/
- - /release\/.*/
- build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-test1"
- docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6"
- resource_class: pytorch/amd-gpu
- - pytorch_linux_test:
- name: pytorch_linux_xenial_rocm3_5_1_py3_6_test2
- requires:
- - pytorch_linux_xenial_rocm3_5_1_py3_6_build
- filters:
- branches:
- only:
- - master
- - /ci-all\/.*/
- - /release\/.*/
- build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-test2"
- docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6"
- resource_class: pytorch/amd-gpu
- - pytorch_linux_test:
- name: pytorch_linux_xenial_rocm3_5_1_py3_6_caffe2_test
- requires:
- - pytorch_linux_xenial_rocm3_5_1_py3_6_build
- filters:
- branches:
- only:
- - master
- - /ci-all\/.*/
- - /release\/.*/
- build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-caffe2_test"
- docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6"
- resource_class: pytorch/amd-gpu
- pytorch_linux_build:
name: pytorch_linux_xenial_py3_6_gcc5_4_build
requires:
@@ -6652,13 +6607,6 @@ workflows:
name: pytorch_linux_xenial_py3_clang5_mobile_build
requires:
- docker-pytorch-linux-xenial-py3-clang5-asan
- - pytorch_linux_build:
- build_environment: pytorch-linux-xenial-py3-clang5-mobile-custom-build-static
- build_only: "1"
- docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan
- name: pytorch_linux_xenial_py3_clang5_mobile_custom_build_static
- requires:
- - docker-pytorch-linux-xenial-py3-clang5-asan
- pytorch_linux_build:
build_environment: pytorch-linux-xenial-py3-clang5-mobile-custom-build-dynamic
build_only: "1"
diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh
index d310a125a98ec95..c2057901b2d7bda 100755
--- a/.circleci/docker/build.sh
+++ b/.circleci/docker/build.sh
@@ -160,7 +160,6 @@ case "$image" in
KATEX=yes
;;
pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7)
- UBUNTU_VERSION=16.04-rc
CUDA_VERSION=11.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.6
@@ -230,7 +229,6 @@ case "$image" in
VISION=yes
;;
pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9)
- UBUNTU_VERSION=18.04-rc
CUDA_VERSION=11.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.6
@@ -241,7 +239,6 @@ case "$image" in
KATEX=yes
;;
pytorch-linux-bionic-cuda11.0-cudnn8-py3.8-gcc9)
- UBUNTU_VERSION=18.04-rc
CUDA_VERSION=11.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.8
diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh
index b25ab843006594e..2a1c2bd0ea8fad1 100755
--- a/.circleci/docker/common/install_conda.sh
+++ b/.circleci/docker/common/install_conda.sh
@@ -86,6 +86,8 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
conda_install magma-cuda101 -c pytorch
elif [[ "$CUDA_VERSION" == 10.2* ]]; then
conda_install magma-cuda102 -c pytorch
+ elif [[ "$CUDA_VERSION" == 11.0* ]]; then
+ conda_install magma-cuda110 -c pytorch
fi
# TODO: This isn't working atm
diff --git a/.circleci/scripts/binary_upload.sh b/.circleci/scripts/binary_upload.sh
index 139f5e590c8ea8b..9282674d3a47584 100755
--- a/.circleci/scripts/binary_upload.sh
+++ b/.circleci/scripts/binary_upload.sh
@@ -30,7 +30,7 @@ do_backup() {
(
pushd /tmp/workspace
set -x
- ${AWS_S3_CP} --recursive . "${BACKUP_BUCKET}/${CIRCLE_TAG}/${backup_dir}"
+ ${AWS_S3_CP} --recursive . "${BACKUP_BUCKET}/${CIRCLE_TAG}/${backup_dir}/"
)
}
@@ -52,7 +52,7 @@ s3_upload() {
local pkg_type
extension="$1"
pkg_type="$2"
- s3_dir="${UPLOAD_BUCKET}/${pkg_type}/${UPLOAD_CHANNEL}/${UPLOAD_SUBFOLDER}"
+ s3_dir="${UPLOAD_BUCKET}/${pkg_type}/${UPLOAD_CHANNEL}/${UPLOAD_SUBFOLDER}/"
(
for pkg in ${PKG_DIR}/*.${extension}; do
(
diff --git a/.jenkins/pytorch/build-mobile.sh b/.jenkins/pytorch/build-mobile.sh
index 03b836b76c3f67f..b1234f2728130e5 100755
--- a/.jenkins/pytorch/build-mobile.sh
+++ b/.jenkins/pytorch/build-mobile.sh
@@ -22,9 +22,7 @@ retry pip install --pre torch torchvision \
# Run end-to-end process of building mobile library, linking into the predictor
# binary, and running forward pass with a real model.
-if [[ "$BUILD_ENVIRONMENT" == *-mobile-custom-build-static* ]]; then
- TEST_CUSTOM_BUILD_STATIC=1 test/mobile/custom_build/build.sh
-elif [[ "$BUILD_ENVIRONMENT" == *-mobile-custom-build-dynamic* ]]; then
+if [[ "$BUILD_ENVIRONMENT" == *-mobile-custom-build-dynamic* ]]; then
export LLVM_DIR="$(llvm-config-5.0 --prefix)"
echo "LLVM_DIR: ${LLVM_DIR}"
TEST_CUSTOM_BUILD_DYNAMIC=1 test/mobile/custom_build/build.sh
diff --git a/BUILD.bazel b/BUILD.bazel
index da50ea2d4b8c424..f7be71ec624d1d2 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -2,7 +2,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_proto_library", "cc_test")
-load("//third_party:substitution.bzl", "template_rule")
+load("//third_party:substitution.bzl", "header_template_rule")
load("//:tools/build_variables.bzl", "torch_cpp_srcs", "libtorch_python_core_sources", "libtorch_core_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "jit_core_sources")
load("//tools/rules:cu.bzl", "cu_library")
load("//tools/config:defs.bzl", "if_cuda")
@@ -27,19 +27,18 @@ COMMON_COPTS = [
])
# c10
-template_rule(
+header_template_rule(
name = "cmake_macros_h",
src = "c10/macros/cmake_macros.h.in",
out = "c10/macros/cmake_macros.h",
substitutions = {
"cmakedefine": "define",
"#define FEATURE_TORCH_MOBILE": "/* #undef FEATURE_TORCH_MOBILE */",
- "#define USE_STATIC_DISPATCH": "/* #undef USE_STATIC_DISPATCH */",
"#define C10_USE_NUMA": "/* #undef C10_USE_NUMA */",
},
)
-template_rule(
+header_template_rule(
name = "cuda_cmake_macros_h",
src = "c10/cuda/impl/cuda_cmake_macros.h.in",
out = "c10/cuda/impl/cuda_cmake_macros.h",
@@ -58,13 +57,12 @@ cc_library(
"c10/macros/*.h",
"c10/util/*.h",
"c10/util/*.hpp",
- ]) + [
- "c10/macros/cmake_macros.h",
- "c10/cuda/impl/cuda_cmake_macros.h",
- ],
+ ]),
deps = [
"@com_github_gflags_gflags//:gflags",
"@com_github_glog//:glog",
+ ":cmake_macros_h",
+ ":cuda_cmake_macros_h",
],
)
@@ -531,7 +529,7 @@ filegroup(
],
)
-template_rule(
+header_template_rule(
name = "aten_src_ATen_config",
src = "aten/src/ATen/Config.h.in",
out = "aten/src/ATen/Config.h",
@@ -547,7 +545,7 @@ template_rule(
},
)
-template_rule(
+header_template_rule(
name = "aten_src_ATen_cuda_config",
src = "aten/src/ATen/cuda/CUDAConfig.h.in",
out = "aten/src/ATen/cuda/CUDAConfig.h",
@@ -558,7 +556,7 @@ template_rule(
},
)
-template_rule(
+header_template_rule(
name = "aten_src_TH_THGeneral",
src = "aten/src/TH/THGeneral.h.in",
out = "aten/src/TH/THGeneral.h",
@@ -570,7 +568,7 @@ template_rule(
},
)
-template_rule(
+header_template_rule(
name = "aten_src_THC_THCGeneral",
src = "aten/src/THC/THCGeneral.h.in",
out = "aten/src/THC/THCGeneral.h",
@@ -582,8 +580,6 @@ template_rule(
cc_library(
name = "aten_headers",
hdrs = [
- "aten/src/TH/THGeneral.h",
- "aten/src/THC/THCGeneral.h",
"torch/csrc/WindowsTorchApiMacro.h",
"torch/csrc/jit/frontend/function_schema_parser.h",
] + glob([
@@ -605,6 +601,8 @@ cc_library(
],
deps = [
":c10_headers",
+ ":aten_src_TH_THGeneral",
+ ":aten_src_THC_THCGeneral",
],
)
@@ -766,7 +764,7 @@ cc_proto_library(
deps = [":caffe2_proto_source"],
)
-template_rule(
+header_template_rule(
name = "caffe2_core_macros_h",
src = "caffe2/core/macros.h.in",
out = "caffe2/core/macros.h",
@@ -1586,7 +1584,6 @@ filegroup(
cc_library(
name = "caffe2_for_aten_headers",
hdrs = [
- "caffe2/core/macros.h",
"caffe2/core/common.h",
"caffe2/core/logging.h",
"caffe2/core/types.h",
@@ -1604,6 +1601,7 @@ cc_library(
deps = [
":c10_headers",
":caffe2_protos",
+ ":caffe2_core_macros_h",
],
)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a842fe3ba9ee7a9..fb7b306547bbeb9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -124,7 +124,6 @@ option(BUILD_PYTHON "Build Python binaries" ON)
option(BUILD_CAFFE2_OPS "Build Caffe2 operators" ON)
option(BUILD_SHARED_LIBS "Build libcaffe2.so" ON)
option(BUILD_CAFFE2_MOBILE "Build libcaffe2 for mobile (deprecating)" OFF)
-option(USE_STATIC_DISPATCH "Use static dispatch for ATen operators" OFF)
cmake_dependent_option(
CAFFE2_LINK_LOCAL_PROTOBUF "If set, build protobuf inside libcaffe2.so." ON
"BUILD_SHARED_LIBS AND BUILD_CUSTOM_PROTOBUF" OFF)
diff --git a/README.md b/README.md
index ce9d81c3d07caf0..6e1fcfdb828f833 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ PyTorch is a Python package that provides two high-level features:
- Tensor computation (like NumPy) with strong GPU acceleration
- Deep neural networks built on a tape-based autograd system
-You can reuse your favorite Python packages such as NumPy, SciPy and Cython to extend PyTorch when needed.
+You can reuse your favorite Python packages such as NumPy, SciPy, and Cython to extend PyTorch when needed.
- [More about PyTorch](#more-about-pytorch)
- [Installation](#installation)
@@ -45,7 +45,7 @@ At a granular level, PyTorch is a library that consists of the following compone
| [**torch.multiprocessing**](https://pytorch.org/docs/stable/multiprocessing.html) | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
| [**torch.utils**](https://pytorch.org/docs/stable/data.html) | DataLoader and other utility functions for convenience |
-Usually PyTorch is used either as:
+Usually, PyTorch is used either as:
- a replacement for NumPy to use the power of GPUs.
- a deep learning research platform that provides maximum flexibility and speed.
@@ -58,7 +58,7 @@ If you use NumPy, then you have used Tensors (a.k.a. ndarray).
![Tensor illustration](./docs/source/_static/img/tensor_illustration.png)
-PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerates the
+PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the
computation by a huge amount.
We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs
@@ -69,8 +69,8 @@ And they are fast!
PyTorch has a unique way of building neural networks: using and replaying a tape recorder.
-Most frameworks such as TensorFlow, Theano, Caffe and CNTK have a static view of the world.
-One has to build a neural network, and reuse the same structure again and again.
+Most frameworks such as TensorFlow, Theano, Caffe, and CNTK have a static view of the world.
+One has to build a neural network and reuse the same structure again and again.
Changing the way the network behaves means that one has to start from scratch.
With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to
@@ -96,9 +96,9 @@ Our goal is to not reinvent the wheel where appropriate.
### Imperative Experiences
-PyTorch is designed to be intuitive, linear in thought and easy to use.
+PyTorch is designed to be intuitive, linear in thought, and easy to use.
When you execute a line of code, it gets executed. There isn't an asynchronous view of the world.
-When you drop into a debugger, or receive error messages and stack traces, understanding them is straightforward.
+When you drop into a debugger or receive error messages and stack traces, understanding them is straightforward.
The stack trace points to exactly where your code was defined.
We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines.
@@ -220,10 +220,10 @@ If the version of Visual Studio 2017 is higher than 15.6, installing of "VC++ 20
NVTX is a part of CUDA distributive, where it is called "Nsight Compute". To install it onto already installed CUDA run CUDA installation once again and check the corresponding checkbox.
Be sure that CUDA with Nsight Compute is installed after Visual Studio 2017.
-Currently VS 2017, VS 2019 and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise it will use VS 2017.
+Currently, VS 2017, VS 2019, and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise, it will use VS 2017.
If Ninja is selected as the generator, the latest MSVC which is newer than VS 2015 (14.0) will get selected as the underlying toolchain. If you use CMake <= 3.14.2 and has VS 2019 installed, then even if you specify VS 2017 as the generator, VS 2019 will get selected as the generator.
-CUDA and MSVC have strong version dependencies, so even if you use VS 2017 / 2019, you will get build errors like `nvcc fatal : Host compiler targets unsupported OS`. For this kind of problem, please install the corresponding VS toolchain in the table below and then you can either specify the toolset during activation (recommended) or set `CUDAHOSTCXX` to override the cuda host compiler (not recommended if there are big version differences).
+CUDA and MSVC have strong version dependencies, so even if you use VS 2017 / 2019, you will get build errors like `nvcc fatal : Host compiler targets unsupported OS`. For this kind of problem, please install the corresponding VS toolchain in the table below, and then you can either specify the toolset during activation (recommended) or set `CUDAHOSTCXX` to override the CUDA host compiler (not recommended if there are big version differences).
| CUDA version | Newest supported VS version |
| ------------ | ------------------------------------------------------- |
@@ -246,7 +246,7 @@ set CMAKE_GENERATOR_TOOLSET_VERSION=14.11
set DISTUTILS_USE_SDK=1
for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,16^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION%
-:: [Optional] If you want to override the cuda host compiler
+:: [Optional] If you want to override the CUDA host compiler
set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Tools\MSVC\14.11.25503\bin\HostX64\x64\cl.exe
python setup.py install
@@ -291,7 +291,7 @@ should increase shared memory size either with `--ipc=host` or `--shm-size` comm
**NOTE:** Must be built with a docker version > 18.06
-The `Dockerfile` is supplied to build images with cuda support and cudnn v7.
+The `Dockerfile` is supplied to build images with CUDA support and cuDNN v7.
You can pass `PYTHON_VERSION=x.y` make variable to specify which Python version is to be used by Miniconda, or leave it
unset to use the default.
```bash
@@ -319,7 +319,7 @@ on [our website](https://pytorch.org/previous-versions).
## Getting Started
-Three pointers to get you started:
+Three pointers to get you started:
- [Tutorials: get you started with understanding and using PyTorch](https://pytorch.org/tutorials/)
- [Examples: easy to understand pytorch code across all domains](https://github.com/pytorch/examples)
- [The API Reference](https://pytorch.org/docs/)
@@ -341,31 +341,31 @@ Three pointers to get you started:
## Communication
* forums: discuss implementations, research, etc. https://discuss.pytorch.org
* GitHub issues: bug reports, feature requests, install issues, RFCs, thoughts, etc.
-* Slack: The [PyTorch Slack](https://pytorch.slack.com/) hosts a primary audience of moderate to experienced PyTorch users and developers for general chat, online discussions, collaboration etc. If you are a beginner looking for help, the primary medium is [PyTorch Forums](https://discuss.pytorch.org). If you need a slack invite, please fill this form: https://goo.gl/forms/PP1AGvNHpSaJP8to1
-* newsletter: no-noise, one-way email newsletter with important announcements about PyTorch. You can sign-up here: https://eepurl.com/cbG0rv
+* Slack: The [PyTorch Slack](https://pytorch.slack.com/) hosts a primary audience of moderate to experienced PyTorch users and developers for general chat, online discussions, collaboration, etc. If you are a beginner looking for help, the primary medium is [PyTorch Forums](https://discuss.pytorch.org). If you need a slack invite, please fill this form: https://goo.gl/forms/PP1AGvNHpSaJP8to1
+* newsletter: a no-noise, one-way email newsletter with important announcements about PyTorch. You can sign up here: https://eepurl.com/cbG0rv
* Facebook page: important announcements about PyTorch. https://www.facebook.com/pytorch
* for brand guidelines, please visit our website at [pytorch.org](https://pytorch.org/)
## Releases and Contributing
-PyTorch has a 90 day release cycle (major releases). Please let us know if you encounter a bug by [filing an issue](https://github.com/pytorch/pytorch/issues).
+PyTorch has a 90-day release cycle (major releases). Please let us know if you encounter a bug by [filing an issue](https://github.com/pytorch/pytorch/issues).
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.
-If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us.
-Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.
+If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us.
+Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the core in a different direction than you might be aware of.
To learn more about making a contribution to Pytorch, please see our [Contribution page](CONTRIBUTING.md).
## The Team
-PyTorch is a community driven project with several skillful engineers and researchers contributing to it.
+PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
PyTorch is currently maintained by [Adam Paszke](https://apaszke.github.io/), [Sam Gross](https://github.com/colesbury), [Soumith Chintala](http://soumith.ch) and [Gregory Chanan](https://github.com/gchanan) with major contributions coming from hundreds of talented individuals in various forms and means.
A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito.
-Note: this project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor in the Torch community and has helped with many things Torch and PyTorch.
+Note: this project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch.
## License
-PyTorch is BSD-style licensed, as found in the [LICENSE](LICENSE) file.
+PyTorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.
diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/BatchedFallback.cpp
index 7869db6006bc365..aaafc341334c4e6 100644
--- a/aten/src/ATen/BatchedFallback.cpp
+++ b/aten/src/ATen/BatchedFallback.cpp
@@ -123,12 +123,14 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
// We assume that torch::jit::Stack is backed by vector for
// simplicity. When that is not the case, this code should be updated.
const auto& argument = (*stack)[arguments_begin + arg_idx];
- if (arg_idx != *batched_tensor_inputs_pos_iter) {
+ if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
+ || arg_idx != *batched_tensor_inputs_pos_iter) {
// argument isn't a BatchedTensor
torch::jit::push(stack, argument);
continue;
}
// argument is a BatchedTensor
+ TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
const auto& physical_view_for_argument = *input_physical_views_iter;
torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
batched_tensor_inputs_pos_iter++;
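The guard added above handles the case where every batched argument has already been consumed: `batched_tensor_inputs_pos_iter` may equal `end()` for the remaining plain arguments, so it must be compared against the end iterator before being dereferenced, and the physical-views iterator is asserted to still be in range before use. A minimal sketch of the same lockstep-iteration guard, with hypothetical names:

```cpp
// Two sequences walked in lockstep: positions of "special" arguments and one payload
// entry per such position. Never dereference pos_it without checking it is in range.
#include <cassert>
#include <cstddef>
#include <vector>

void walk(const std::vector<int>& args,
          const std::vector<std::size_t>& special_positions,  // sorted indices into args
          const std::vector<double>& payload_for_special) {   // one entry per special index
  auto pos_it = special_positions.begin();
  auto payload_it = payload_for_special.begin();
  for (std::size_t idx = 0; idx < args.size(); ++idx) {
    if (pos_it == special_positions.end() || idx != *pos_it) {
      continue;  // ordinary argument: nothing to consume here
    }
    assert(payload_it != payload_for_special.end());  // mirrors the TORCH_INTERNAL_ASSERT above
    (void)*payload_it;                                 // safe to dereference now
    ++pos_it;
    ++payload_it;
  }
}
```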
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index e97b3d368d1c600..33ac088b70a7c3c 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -64,7 +64,7 @@ file(GLOB native_cpp "native/*.cpp")
file(GLOB native_mkl_cpp "native/mkl/*.cpp")
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
file(GLOB vulkan_cpp "vulkan/*.cpp")
-file(GLOB native_vulkan_cpp "native/vulkan/*.cpp")
+file(GLOB native_vulkan_cpp "native/vulkan/api/*.cpp" "native/vulkan/*.cpp")
file(GLOB native_sparse_cpp "native/sparse/*.cpp")
file(GLOB native_quantized_cpp
"native/quantized/*.cpp"
diff --git a/aten/src/ATen/MemoryOverlap.h b/aten/src/ATen/MemoryOverlap.h
index 7af936b60c1b7fa..67f63a64668c35d 100644
--- a/aten/src/ATen/MemoryOverlap.h
+++ b/aten/src/ATen/MemoryOverlap.h
@@ -24,7 +24,7 @@ CAFFE2_API void assert_no_internal_overlap(TensorImpl* t);
CAFFE2_API MemOverlapStatus get_overlap_status(const Tensor& a, const Tensor& b);
CAFFE2_API MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b);
-void assert_no_partial_overlap(const Tensor& a, const Tensor& b);
+CAFFE2_API void assert_no_partial_overlap(const Tensor& a, const Tensor& b);
void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b);
}
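Adding `CAFFE2_API` gives `assert_no_partial_overlap(const Tensor&, const Tensor&)` exported visibility, so callers in other shared objects (for example test binaries) can link against it instead of hitting unresolved-symbol errors. A generic, illustrative version of the export-macro pattern (not PyTorch's actual macro definitions):

```cpp
// Illustrative only: a typical cross-platform export macro. PyTorch defines its own
// variants (CAFFE2_API, TORCH_API, ...) in its c10 macro headers.
#if defined(_WIN32)
  #if defined(MYLIB_BUILDING_SHARED)
    #define MYLIB_API __declspec(dllexport)
  #else
    #define MYLIB_API __declspec(dllimport)
  #endif
#else
  #define MYLIB_API __attribute__((visibility("default")))
#endif

MYLIB_API void exported_helper();  // visible to other DSOs
void internal_helper();            // may end up hidden, depending on build flags
```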
diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 74c4e0b69cf3169..c7180f9d558ecf0 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -202,8 +202,6 @@ Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::op
"safe to autocast.");
}
-
-#ifndef USE_STATIC_DISPATCH
namespace {
/*****************************************************************************************************************
This section performs load-time registration for autocast wrappers.
@@ -378,7 +376,6 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
}
}
-#endif
} // namespace autocast
} // namespace at
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index cd6fa93515da251..57a36b135eaf429 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -23,8 +23,6 @@ _(aten, _abs) \
_(aten, _addmv) \
_(aten, _addr) \
_(aten, _arange) \
-_(aten, _asinh) \
-_(aten, _atanh) \
_(aten, _argmax) \
_(aten, _argmin) \
_(aten, _baddbmm_mkl) \
@@ -656,8 +654,6 @@ _(aten, stft) \
_(aten, storage_offset) \
_(aten, stride) \
_(aten, strides) \
-_(aten, sub) \
-_(aten, sub_) \
_(aten, rsub) \
_(aten, sum) \
_(aten, sum_to_size) \
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index a38044b1a304e85..9738bfaabb8b7be 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -95,6 +95,7 @@ namespace c10 {
_(prim, Guard) \
_(prim, BailOut) \
_(prim, TypeCheck) \
+ _(prim, FallbackGraph) \
_(prim, FusedConcat) \
_(prim, ConstantChunk) \
_(prim, MMTreeReduce) \
@@ -130,6 +131,7 @@ namespace c10 {
_(prim, GetAttr) \
_(prim, HasAttr) \
_(prim, profile) \
+ _(prim, profile_optional) \
_(prim, AddStatValue) \
_(prim, TimePoint) \
_(prim, CallFunction) \
@@ -156,16 +158,25 @@ namespace c10 {
_(aten, asin_) \
_(aten, arcsin) \
_(aten, arcsin_) \
+ _(aten, asinh) \
+ _(aten, asinh_) \
+ _(aten, arcsinh) \
+ _(aten, arcsinh_) \
_(aten, atan) \
_(aten, atan_) \
_(aten, arctan) \
_(aten, arctan_) \
+ _(aten, atanh) \
+ _(aten, atanh_) \
+ _(aten, arctanh) \
+ _(aten, arctanh_) \
_(aten, clamp) \
_(aten, clamp_) \
_(aten, clip) \
_(aten, clip_) \
_(aten, det) \
_(aten, linalg_det) \
+ _(aten, linalg_norm) \
_(aten, append) \
_(aten, item) \
_(aten, format) \
@@ -203,6 +214,10 @@ namespace c10 {
_(aten, list) \
_(aten, wait) \
_(aten, save) \
+ _(aten, sub) \
+ _(aten, sub_) \
+ _(aten, subtract) \
+ _(aten, subtract_) \
_(aten, keys) \
_(aten, ord) \
_(aten, chr) \
@@ -293,6 +308,8 @@ namespace c10 {
_(attr, inplace) \
_(attr, input_as_shape) \
_(attr, is_zero) \
+ _(attr, num_none) \
+ _(attr, num_present) \
_(attr, perm) \
_(attr, sizes) \
_(attr, starts) \
diff --git a/aten/src/ATen/core/op_registration/op_whitelist.h b/aten/src/ATen/core/op_registration/op_whitelist.h
index 92d71e2628c89e3..cf7f09bb0f9a3e7 100644
--- a/aten/src/ATen/core/op_registration/op_whitelist.h
+++ b/aten/src/ATen/core/op_registration/op_whitelist.h
@@ -60,9 +60,13 @@ constexpr bool op_whitelist_check(string_view op_name) {
#else
return op_whitelist_contains(
C10_STRINGIZE(TORCH_OPERATOR_WHITELIST),
- // Strip overload name (as whitelist doesn't contain overloads)
- OperatorNameView::parse(op_name).name
- );
+ // This function is mainly used for mobile selective build with
+ // root operators, where the overload is included in the whitelist.
+ op_name);
+ // // Strip overload name (as whitelist doesn't contain overloads)
+ // // Another function based on this may be added when there's usage
+ // // on op names without overload.
+ // OperatorNameView::parse(op_name).name);
#endif
}
@@ -76,6 +80,12 @@ constexpr bool schema_whitelist_check(string_view schema) {
#endif
}
+// schema_whitelist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
+// Add this API to allow passing an arbitrary whitelist.
+constexpr bool op_whitelist_contains_name_in_schema(string_view whitelist, string_view schema) {
+ return op_whitelist_contains(whitelist, schema.substr(0, schema.find("(")));
+}
+
// Returns true iff the given dispatch key is on the whitelist
// and should be registered. When we turn this on, the list of valid
// mobile dispatch keys is hard coded (but you need to make sure
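The new `op_whitelist_contains_name_in_schema` compares the whitelist against everything before the first `(` of a full schema string, i.e. the overload-qualified name such as `aten::add.Tensor`. A small standalone sketch of that extraction step (using `std::string_view` rather than the `c10::string_view` the real code uses):

```cpp
#include <iostream>
#include <string_view>

// Everything before the first '(' of a schema string is the qualified, overload-including name.
constexpr std::string_view schema_name(std::string_view schema) {
  return schema.substr(0, schema.find('('));
}

int main() {
  constexpr std::string_view schema =
      "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor";
  static_assert(schema_name(schema) == "aten::add.Tensor");
  std::cout << schema_name(schema) << '\n';  // this prefix is what gets matched against the whitelist
}
```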
diff --git a/aten/src/ATen/cudnn/AutocastRNN.cpp b/aten/src/ATen/cudnn/AutocastRNN.cpp
index 4a900f1309b6a50..31e1a26e8fb7eca 100644
--- a/aten/src/ATen/cudnn/AutocastRNN.cpp
+++ b/aten/src/ATen/cudnn/AutocastRNN.cpp
@@ -104,14 +104,12 @@ _cudnn_rnn_cast_reflatten(const Tensor & input,
#endif // AT_CUDNN_ENABLED()
}
-#ifndef USE_STATIC_DISPATCH
namespace {
TORCH_LIBRARY_IMPL(aten, Autocast, m) {
m.impl("_cudnn_rnn",
TORCH_FN((&at::autocast::_cudnn_rnn_cast_reflatten)));
}
} // anonymous namespace
-#endif
} // namespace autocast
} // namespace at
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index e26bb3941b3a7e6..f996e73e5d9ccdb 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -146,14 +146,10 @@ def TypedDict(name, attrs, total=True): # type: ignore
// ${schema_string}
${return_type} Tensor::${api_name}(${method_formals}) const {
-#ifdef USE_STATIC_DISPATCH
- ${static_dispatch_method_body}
-#else
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::${operator_name}", "${overload_name}")
.typed<${tensor_method_cpp_signature}>();
return op.call(${tensor_method_actuals});
-#endif
}
""")
@@ -172,45 +168,13 @@ def TypedDict(name, attrs, total=True): # type: ignore
// ${schema_string}
${return_type} ${api_name}(${formals}) {
-#ifdef USE_STATIC_DISPATCH
- ${static_dispatch_function_body}
-#else
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::${operator_name}", "${overload_name}")
.typed<${function_cpp_signature}>();
return op.call(${function_actuals});
-#endif
-}
-""")
-
-# In order to rely on the linker to strip unused ops, it requires us to dispatch statically
-# in Functions.h and TensorMethods.cpp.
-#
-# NB: The default body also needs to apply a variable guard, as in some
-# situations what we think is a default body actually does have an
-# explicit derivative, and thereby would have gotten unwrapped by
-# the time you get to the implementation.
-STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\
-at::AutoNonVariableTypeMode _var_guard(true);
-${return_call} TypeDefault::${type_wrapper_name}(${actuals});
-""")
-
-STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\
-at::AutoNonVariableTypeMode _var_guard(true);
-${dispatch_key_init}
-switch (dispatchKeyToBackend(${dispatch_key_var_name})) {
- ${static_dispatch_function_cases}
- default:
- AT_ERROR("${api_name} not implemented for ", at::toString(${dispatch_key_var_name}));
}
""")
-STATIC_DISPATCH_FUNCTION_SWITCH_CASE = CodeTemplate("""\
-case Backend::${backend}:
- ${return_call} ${backend}Type::${type_wrapper_name}(${actuals});
- break;
-""")
-
IFDEF_BLOCK = CodeTemplate("""\
#ifdef ${ifdef_guard}
${content}
@@ -246,10 +210,6 @@ def TypedDict(name, attrs, total=True): # type: ignore
('ComplexDouble', 'ComplexDouble', 'ComplexDouble', False),
]
-static_dispatch_backends = ['CPU', 'QuantizedCPU', 'Vulkan']
-static_dispatch_backends_ifdef_guard = {'Vulkan' : 'USE_VULKAN'}
-
-
class NYIError(Exception):
"""Indicates we don't support this declaration yet"""
@@ -1136,44 +1096,6 @@ def swizzle_self(f): # blegh
method_actuals = maybe_unwrap_optional_tensors(option, formals, option['method_actuals'])
- if isinstance(type_method_dispatch, dict):
- static_dispatch_function_cases = []
- # NB: As this code is currently written, there will NEVER be
- # a backend generated for variable dispatch. There is nothing
- # stopping us from actually implementing this, however, if you
- # really wanted variable on mobile, there's nothing stopping
- # you from implementing this (however, you would have an
- # annoying phase problem, since code generation for variable
- # happens in tools/ which happens later than here.)
- #
- # If you pass in a variable to the dispatch, and variable is
- # enabled, this switch will fail. This is intentional: you
- # probably need to disable variable globally in the mobile
- # calling code.
- for backend in static_dispatch_backends:
- if backend in type_method_dispatch:
- static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
- option,
- backend=backend,
- backend_function=type_method_dispatch[backend],
- actuals=method_actuals)
- if (backend in static_dispatch_backends_ifdef_guard):
- static_dispatch_function_cases.append(IFDEF_BLOCK.substitute(
- option,
- ifdef_guard=static_dispatch_backends_ifdef_guard[backend],
- content=static_dispatch_function_case))
- else:
- static_dispatch_function_cases.append(static_dispatch_function_case)
-
- static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
- option,
- dispatch_key_var_name=dispatch_key_var_name,
- dispatch_key_init=dispatch_key_init,
- static_dispatch_function_cases=static_dispatch_function_cases)
- else:
- static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute(
- option, actuals=method_actuals)
-
# See NOTE[UnboxedOnly]
if option['use_c10_dispatcher'] == 'full':
tensor_method_actuals = option['schema_order_method_actuals']
@@ -1184,13 +1106,12 @@ def swizzle_self(f): # blegh
tensor_method_cpp_signature = option['cpp_signature']
method_definition = TENSOR_METHOD_DEFINITION.substitute(
- option, static_dispatch_method_body=static_dispatch_method_body,
+ option,
tensor_method_actuals=tensor_method_actuals,
tensor_method_cpp_signature=tensor_method_cpp_signature
)
return FunctionCode(
- declaration=TENSOR_METHOD_DECLARATION.substitute(
- option, static_dispatch_method_body=static_dispatch_method_body),
+ declaration=TENSOR_METHOD_DECLARATION.substitute(option),
definition=method_definition)
def gen_namespace_function(option, multidispatch_formals):
@@ -1204,31 +1125,6 @@ def gen_namespace_function(option, multidispatch_formals):
actuals = maybe_unwrap_optional_tensors(option, formals, option['actuals'])
- if isinstance(type_method_dispatch, dict):
- static_dispatch_function_cases = []
- for backend in static_dispatch_backends:
- if backend in type_method_dispatch:
- static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
- option,
- backend=backend,
- backend_function=type_method_dispatch[backend],
- actuals=actuals)
- if (backend in static_dispatch_backends_ifdef_guard):
- static_dispatch_function_cases.append(IFDEF_BLOCK.substitute(
- option,
- ifdef_guard=static_dispatch_backends_ifdef_guard[backend],
- content=static_dispatch_function_case))
- else:
- static_dispatch_function_cases.append(static_dispatch_function_case)
- static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
- option,
- dispatch_key_var_name=dispatch_key_var_name,
- dispatch_key_init=dispatch_key_init,
- static_dispatch_function_cases=static_dispatch_function_cases)
- else:
- static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute(
- option, actuals=actuals)
-
# See NOTE[UnboxedOnly]
if option['use_c10_dispatcher'] == 'full':
function_actuals = option['schema_order_actuals']
@@ -1239,7 +1135,7 @@ def gen_namespace_function(option, multidispatch_formals):
function_cpp_signature = option['cpp_signature']
fn_definition = FUNCTION_DEFINITION.substitute(
- option, static_dispatch_function_body=static_dispatch_function_body,
+ option,
function_actuals=function_actuals,
function_cpp_signature=function_cpp_signature)
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index 0e301df3c9afdef..a43f16465db84e8 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -372,7 +372,16 @@ static Tensor threshold_out(
Scalar value,
const Tensor& other) {
Tensor result = opt_result.value_or(Tensor());
- auto iter = TensorIterator::binary_op(result, self, other);
+ auto iter = TensorIteratorConfig()
+ .set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay
+ .add_output(result)
+ .add_input(self)
+ .add_input(other)
+ .allow_cpu_scalars(true)
+ .promote_inputs_to_common_dtype(true)
+ .cast_common_dtype_to_outputs(true)
+ .enforce_safe_casting_to_output(true)
+ .build();
threshold_stub(iter.device_type(), iter, threshold, value);
return iter.output();
}
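`threshold_out` now builds its iterator explicitly so it can opt out of the memory-overlap check: each output element of `threshold` depends only on the corresponding input element, so even exact input/output aliasing (the in-place variant, which reuses `self` as the result) produces the same answer. A hedged usage sketch of that aliased case:

```cpp
// Hedged sketch: the in-place variant aliases result and input, which is precisely the
// overlap the config above chooses to allow for this idempotent op.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::randn({4});
  at::threshold_(t, /*threshold=*/0.0, /*value=*/0.0);  // ReLU-like clamp, computed in place
  std::cout << t << "\n";
  return 0;
}
```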
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 50a758d8f3b240f..bccba591a529cc3 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -48,9 +48,14 @@ DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
+static Tensor wrapped_scalar_tensor(Scalar scalar) {
+ auto tensor = scalar_to_tensor(scalar);
+ tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
+ return tensor;
+}
+
Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
alpha_check(iter.dtype(), alpha);
add_stub(iter.device_type(), iter, alpha);
TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype());
@@ -71,8 +76,7 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
Tensor& add_relu_impl(
Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
Scalar min_val;
Scalar max_val;
if (self.dtype() == at::kInt) {
@@ -124,8 +128,7 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
"Use true_divide or floor_divide (// in Python) instead.");
}
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
div_stub(iter.device_type(), iter);
return result;
}
@@ -150,8 +153,7 @@ Tensor& div_(Tensor& self, const Tensor& other) {
}
Tensor& remainder_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
remainder_stub(iter.device_type(), iter);
return result;
}
@@ -212,8 +214,7 @@ Tensor& true_divide_(Tensor& self, const Tensor& divisor) {
}
Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
div_stub(iter.device_type(), iter);
if (result.is_floating_point()) {
@@ -242,8 +243,7 @@ Tensor& floor_divide_(Tensor& self, const Tensor& other) {
}
Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
mul_stub(iter.device_type(), iter);
return result;
}
@@ -261,8 +261,7 @@ Tensor& mul_(Tensor& self, const Tensor& other) {
Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
sub_check(self, other);
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
alpha_check(iter.dtype(), alpha);
sub_stub(iter.device_type(), iter, alpha);
TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype());
@@ -282,6 +281,35 @@ Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) {
return native::sub_out(self, self, other, alpha);
}
+Tensor sub(const Tensor& self, Scalar other, Scalar alpha) {
+ return native::sub(self, wrapped_scalar_tensor(other), alpha);
+}
+
+Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) {
+ return native::sub_(self, wrapped_scalar_tensor(other), alpha);
+}
+
+// subtract, alias for sub
+Tensor& subtract_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
+ return at::sub_out(result, self, other, alpha);
+}
+
+Tensor subtract(const Tensor& self, const Tensor& other, Scalar alpha) {
+ return self.sub(other, alpha);
+}
+
+Tensor& subtract_(Tensor& self, const Tensor& other, Scalar alpha) {
+ return self.sub_(other, alpha);
+}
+
+Tensor subtract(const Tensor& self, Scalar other, Scalar alpha) {
+ return self.sub(other, alpha);
+}
+
+Tensor& subtract_(Tensor& self, Scalar other, Scalar alpha) {
+ return self.sub_(other, alpha);
+}
+
Tensor& sigmoid_backward_out(Tensor& result, const Tensor& grad_output, const Tensor& output) {
auto iter = TensorIterator::binary_op(result, grad_output, output);
sigmoid_backward_stub(iter.device_type(), iter);
@@ -353,12 +381,6 @@ Tensor& atan2_(Tensor& self, const Tensor& other) {
// types (int, float, etc.) to Tensor (only to Scalar). They're not exposed
// to Python.
-static Tensor wrapped_scalar_tensor(Scalar scalar) {
- auto tensor = scalar_to_tensor(scalar);
- tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
- return tensor;
-}
-
static void check_convert(Scalar scalar, ScalarType scalarType) {
// Validate that is possible to convert scalar to tensor dtype without overflow
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{
@@ -425,21 +447,12 @@ Tensor& mul_(Tensor& self, Scalar other) {
return native::mul_(self, wrapped_scalar_tensor(other));
}
-Tensor sub(const Tensor& self, Scalar other, Scalar alpha) {
- return native::sub(self, wrapped_scalar_tensor(other), alpha);
-}
-
-Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) {
- return native::sub_(self, wrapped_scalar_tensor(other), alpha);
-}
-
Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
return native::rsub(self, wrapped_scalar_tensor(other), alpha);
}
Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
bitwise_and_stub(iter.device_type(), iter);
return result;
}
@@ -485,8 +498,7 @@ Tensor& __iand__(Tensor& self, Scalar other) {
}
Tensor& bitwise_or_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
bitwise_or_stub(iter.device_type(), iter);
return result;
}
@@ -532,8 +544,7 @@ Tensor& __ior__(Tensor& self, Scalar other) {
}
Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
bitwise_xor_stub(iter.device_type(), iter);
return result;
}
@@ -644,7 +655,7 @@ Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& othe
check_convert(self.item(), other.scalar_type());
}
}
- auto iter = TensorIterator::comparison_op(result, self, other, /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::comparison_op(result, self, other);
stub(iter.device_type(), iter);
return result;
}
@@ -752,8 +763,7 @@ Tensor& logical_xor_(Tensor& self, Scalar other) { return comparison_op_(self, o
Tensor& maximum_out(Tensor& result, const Tensor& self, const Tensor& other) {
TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs.");
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
maximum_stub(iter.device_type(), iter);
return result;
}
@@ -779,8 +789,7 @@ Tensor max(const Tensor& self, const Tensor& other) {
Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) {
TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs.");
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
minimum_stub(iter.device_type(), iter);
return result;
}
@@ -812,16 +821,14 @@ Tensor& floor_divide_(Tensor& self, Scalar other) {
}
Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU");
fmod_stub(iter.device_type(), iter);
return result;
}
Tensor& fmod_out(Tensor & result, const Tensor& self, Scalar other) {
- auto iter = TensorIterator::unary_op(result, self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, self);
TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU");
fmod_scalar_stub(iter.device_type(), iter, other);
return result;
@@ -846,7 +853,7 @@ Tensor& fmod_(Tensor& self, Scalar other) {
}
Tensor& logaddexp_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
logaddexp_stub(iter.device_type(), iter);
return result;
}
@@ -857,7 +864,7 @@ Tensor logaddexp(const Tensor& self, const Tensor& other) {
}
Tensor& logaddexp2_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, self, other);
logaddexp2_stub(iter.device_type(), iter);
return result;
}
@@ -868,7 +875,7 @@ Tensor logaddexp2(const Tensor& self, const Tensor& other) {
}
Tensor& gcd_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/ true);
+ auto iter = TensorIterator::binary_op(result, self, other);
gcd_stub(iter.device_type(), iter);
return result;
}
@@ -883,7 +890,7 @@ Tensor& gcd_(Tensor& self, const Tensor& other) {
}
Tensor& lcm_out(Tensor& result, const Tensor& self, const Tensor& other) {
- auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/ true);
+ auto iter = TensorIterator::binary_op(result, self, other);
lcm_stub(iter.device_type(), iter);
return result;
}
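Two related threads run through this file: `TensorIterator::binary_op` no longer takes an explicit `/*check_mem_overlap=*/true` argument (overlap checking is now the default for these ops), and `sub` gains NumPy-style `subtract`/`subtract_`/`subtract_out` aliases that simply forward to the existing implementations. A hedged usage sketch of the alias, assuming the matching `native_functions.yaml` entries (not shown in this hunk) expose it as `at::subtract` and `Tensor::subtract`:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::ones({2, 2});
  at::Tensor b = at::full({2, 2}, 3.0);
  at::Tensor c = a.subtract(b, /*alpha=*/2.0);  // forwards to a.sub(b, 2.0): 1 - 2*3 = -5
  at::Tensor d = at::subtract(a, b);            // same result as at::sub(a, b)
  std::cout << c << "\n" << d << "\n";
  return 0;
}
```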
diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp
index fc3c7d637a3ee89..606b381ca83b5d4 100644
--- a/aten/src/ATen/native/Blas.cpp
+++ b/aten/src/ATen/native/Blas.cpp
@@ -54,6 +54,8 @@ Tensor &addmv_out(Tensor& result, const Tensor &self, const Tensor &mat, const T
"size mismatch, get ", self_.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
if (mat.numel() == 0) {
+ // By definition, when beta==0, values in self should be ignored. nans and infs
+ // should not propagate
if (beta.toDouble() == 0.0) {
result.zero_();
} else {
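The comment above pins down the BLAS convention for `beta == 0`: the existing values in `self` must be ignored outright rather than scaled by zero, because `0 * NaN` and `0 * Inf` are NaN under IEEE arithmetic; that is why the empty-`mat` path calls `result.zero_()` explicitly. A hedged illustration:

```cpp
// With beta == 0, NaNs in self must not leak into the result even though 0 * NaN == NaN.
#include <ATen/ATen.h>
#include <iostream>
#include <limits>

int main() {
  at::Tensor self = at::full({2}, std::numeric_limits<double>::quiet_NaN(), at::kDouble);
  at::Tensor mat  = at::empty({2, 0}, at::kDouble);  // zero-sized inner dimension
  at::Tensor vec  = at::empty({0}, at::kDouble);
  at::Tensor out  = at::addmv(self, mat, vec, /*beta=*/0, /*alpha=*/1);
  std::cout << out << "\n";  // expected: zeros, not NaNs
  return 0;
}
```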
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 158c8d9d8aa035a..212dd15eef33587 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -268,7 +268,8 @@ auto ConvParams::use_xnnpack(
padding,
stride,
dilation,
- groups);
+ groups,
+ transposed);
}
#endif
return false;
diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp
index 736c65b6e8bfc7e..d02dc781ed19a55 100644
--- a/aten/src/ATen/native/Fill.cpp
+++ b/aten/src/ATen/native/Fill.cpp
@@ -29,7 +29,12 @@ Tensor& fill_out(Tensor& self, Scalar value) {
fill_fast(self, value);});
return self;
}
- auto iter = TensorIterator::nullary_op(self);
+ auto iter = TensorIteratorConfig()
+ .set_check_mem_overlap(false) // Fill is idempotent, so overlap is okay
+ .check_all_same_dtype(false)
+ .add_output(self)
+ .resize_outputs(false)
+ .build();
fill_stub(iter.device_type(), iter, value);
return self;
}
diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp
index dcf74e89a92f01e..7caa6218b63aa6a 100644
--- a/aten/src/ATen/native/ForeachOpsKernels.cpp
+++ b/aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -1,14 +1,24 @@
#include <ATen/ATen.h>
+#include <ATen/native/ForeachUtils.h>
+
namespace at { namespace native {
-std::vector<Tensor> foreach_add_scalar_kernel_fallback(TensorList tensors, Scalar scalar) {
- TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
+std::vector<Tensor> foreach_tensor_add_scalar_kernel_slow(TensorList tensors, Scalar scalar) {
+ verify_list(tensors);
std::vector<Tensor> result;
- for (int i = 0; i < tensors.size(); i++) {
- auto temp = tensors[i].add(scalar);
- result.emplace_back(temp);
+ for (const auto& t : tensors) {
+ result.emplace_back(t.add(scalar));
}
return result;
}
+
+void foreach_tensor_add_scalar_kernel_slow_(TensorList tensors, Scalar scalar) {
+ verify_list(tensors);
+
+ for (auto& t : tensors) {
+ t.add_(scalar);
+ }
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h
new file mode 100644
index 000000000000000..1afab4c6ba2860b
--- /dev/null
+++ b/aten/src/ATen/native/ForeachUtils.h
@@ -0,0 +1,58 @@
+#pragma once
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+namespace {
+
+void verify_list(TensorList tensors) {
+ TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
+ auto expected_dtype = tensors[0].dtype();
+
+ for (auto t : tensors) {
+ TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
+ }
+}
+
+// To go via 'fast' path, several conditions must be satisfied
+// - All tensors must have strided layout
+// - All tensors must be non-overlapping and dense
+// - Resulting tensor must have the same dtype as the input one
+bool check_fast_route(TensorList tensors, Scalar scalar) {
+ TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
+ auto expected_device = tensors[0].device();
+
+ for (auto t : tensors) {
+ if (t.device() != expected_device) {
+ return false;
+ }
+
+ if (t.layout() != at::kStrided) {
+ return false;
+ }
+
+ if (!t.is_non_overlapping_and_dense()) {
+ return false;
+ }
+
+ // complex scalar + integral or boolean tensor will result in complex tensor
+ if (scalar.isComplex() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) {
+ return false;
+ }
+
+ // float scalar + integral or boolean tensor will result in float tensor
+ if (scalar.isFloatingPoint() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) {
+ return false;
+ }
+
+ // integral scalar + boolean tensor will result in integral tensor
+ if (scalar.isIntegral(/*includeBool*/ false) && t.dtype() == at::kBool) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} // namespace
+}} // at::native
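`verify_list` enforces the basic invariant (non-empty list, uniform dtype), while `check_fast_route` decides whether a fused multi-tensor kernel can be used without changing the result: mixed devices, non-strided layouts, overlapping or non-dense tensors, or a scalar whose type would promote the output dtype all force the slow per-tensor loop seen in `ForeachOpsKernels.cpp` above. A hedged illustration of one of the rejected cases (dtype promotion by a floating-point scalar on an integral tensor):

```cpp
// A float scalar added to an int tensor promotes the result dtype; check_fast_route
// rejects this case (see the scalar/dtype checks above), so the slow path must run instead.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::ones({3}, at::kInt);
  at::Tensor r = t.add(1.5);  // type promotion: result becomes a floating-point tensor
  std::cout << c10::toString(r.scalar_type()) << "\n";        // Float, not Int
  std::cout << (r.scalar_type() == t.scalar_type()) << "\n";  // 0: dtypes differ
  return 0;
}
```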
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 1f71a3cc99c8d13..a692ca0a800e7a0 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -4,6 +4,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -1284,10 +1285,13 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
if (dim.size() == 1 || dim.size() == 0) {
return at::norm(self, 2, dim, keepdim);
}
+ auto dim_ = dim.vec();
+ maybe_wrap_dims(dim_, self.dim());
+ TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
if (self.is_complex()){
- return at::sqrt(at::sum(at::real(self.conj() * self), dim, keepdim));
+ return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
} else {
- return at::sqrt(at::sum((self * self), dim, keepdim));
+ return at::sqrt(at::sum((self * self), dim_, keepdim));
}
}
@@ -1305,10 +1309,13 @@ Tensor &frobenius_norm_out(
if (dim.size() == 1 || dim.size() == 0) {
return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type());
}
+ auto dim_ = dim.vec();
+ maybe_wrap_dims(dim_, self.dim());
+ TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
if (self.is_complex()){
- return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim, keepdim));
+ return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim_, keepdim));
} else {
- return at::sqrt_out(result, at::sum((self * self), dim, keepdim));
+ return at::sqrt_out(result, at::sum((self * self), dim_, keepdim));
}
}
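Both `frobenius_norm` overloads now canonicalize negative dims with `maybe_wrap_dims` before checking that the two entries differ, so `dim = {-2, -1}` on a 3-D tensor becomes `{1, 2}`, while something like `{1, -2}` is rejected because both entries wrap to the same dimension. The norm itself is unchanged: the square root of the sum of squares over the two dims. A hedged usage sketch:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({2, 3, 4});
  // dim = {-2, -1} wraps to {1, 2}; with keepdim == false the result has shape {2}.
  at::Tensor n = at::frobenius_norm(x, {-2, -1}, /*keepdim=*/false);
  // Equivalent by definition: sqrt(sum(x * x)) over the same dims.
  at::Tensor ref = at::sqrt(at::sum(x * x, {1, 2}));
  std::cout << n << "\n" << at::allclose(n, ref) << "\n";  // prints the norms, then 1
  // dim = {1, -2} would now raise: both entries wrap to dimension 1.
  return 0;
}
```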
@@ -1342,8 +1349,10 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+ auto dim_ = dim.vec();
+ maybe_wrap_dims(dim_, self.dim());
- auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+ auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
auto permutation_reverse = create_reverse_permutation(permutation);
Tensor p = self.permute(permutation);
// Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
@@ -1360,19 +1369,243 @@ Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+ auto dim_ = dim.vec();
+ maybe_wrap_dims(dim_, self.dim());
- auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+ auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
auto permutation_reverse = create_reverse_permutation(permutation);
Tensor p = self.permute(permutation);
at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
if (keepdim) {
result.unsqueeze_(-1);
- result = result.permute(permutation_reverse);
+ Tensor result_ = result.permute(permutation_reverse);
+ result.set_(result_);
+ }
+ return result;
+}
+
+// Creates a vector of length ndim with values equal to its indices
+// (e.g. [0, 1, 2, ..., ndim-1])
+static std::vector<int64_t> make_dim_list(int64_t ndim) {
+ std::vector<int64_t> dim_list(ndim);
+ for (int64_t ind = 0; ind < ndim; ind++) {
+ dim_list[ind] = ind;
+ }
+ return dim_list;
+}
+
+// Checks for valid arguments to linalg_norm when type(ord) == str
+static void check_str_ord_valid(const std::string& str_ord, optional<IntArrayRef> opt_dim, int64_t ndim, optional<ScalarType> opt_dtype) {
+ TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord);
+ TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument");
+ bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2);
+ TORCH_CHECK(dims_valid, "order \"", str_ord,
+ "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)");
+}
+
+// Performs vector norm for ord = +/-infinity, and the second dimension reduction
+// for matrix norms.
+static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) {
+ Tensor result;
+ if (self.numel() == 0 && self.sizes()[dim] > 0) {
+ // This special case is needed in matrix norm for tensors with 3 or more dims,
+ // or in vector norm for order inf and -inf for tensors with 2 or more dims.
+ // When the sizes of the dims to be reduced are greater than 0 but another dim
+ // in the tensor is size 0 (thus numel == 0), we must either flatten or resize
+ // the second reduction dim to 1, to avoid calling min/max, which would throw
+ // an error.
+ if (self.sizes()[dim] != 1) {
+ auto new_sizes = self.sizes().vec();
+ new_sizes[dim] = 1;
+ self.resize_(new_sizes);
+ }
+ result = keepdim ? self : self.flatten(dim);
+ } else {
+ if (ord > 0) {
+ result = std::get<0>(self.max(dim, keepdim));
+ } else {
+ result = std::get<0>(self.min(dim, keepdim));
+ }
+ }
+ return result;
+}
+
+// Performs matrix norm
+static Tensor _linalg_norm_matrix(const Tensor &self, optional<Scalar> opt_ord,
+ IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ Tensor result;
+ auto ord = opt_ord.value_or(2.0).toDouble();
+ TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
+ "matrix norm only supports CPU AND CUDA device type, got: ", self.device().type());
+ TORCH_CHECK(self.layout() == Layout::Strided,
+ "matrix norm only supports strided layout, got: ", self.layout());
+
+ TORCH_CHECK(dim.size() == 2, "_linalg_norm_matrix: 'dim' must either specify 2 dimensions. ",
+ "Got 'dim' specifying ", dim.size(), " dims");
+ auto dim_ = dim.vec();
+ maybe_wrap_dims(dim_, self.dim());
+ TORCH_CHECK(dim_[0] != dim_[1],
+ "Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead");
+
+ ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
+ TORCH_CHECK(
+ at::isFloatingType(scalarType) || at::isComplexType(scalarType),
+ "Can only calculate the mean of floating and complex types. Got ",
+ toString(scalarType), " instead.");
+
+ Tensor self_;
+ if (opt_dtype.has_value()) {
+ self_ = self.to(scalarType);
+ } else {
+ self_ = self;
+ }
+
+ if (std::abs(ord) == 2) {
+ // Need to shift the reduction dims to the back, because at::svd will only operate on
+ // the last 2 dimensions
+ auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
+ auto permutation_reverse = create_reverse_permutation(permutation);
+
+ result = std::get<1>(self_.permute(permutation).svd()).abs();
+ result = _norm_min_max(result, ord, result.dim() - 1, keepdim);
+
+ if (keepdim) {
+ result.unsqueeze_(-1);
+ result = result.permute(permutation_reverse);
+ }
+ } else {
+ // abs(p) == infinity and abs(p) == 1 will perform identical reductions, except
+ // that the order of the two dims is swapped. So we can swap the dims if
+ // abs(p) == infinity to simplify the rest of the operation's logic.
+ if (std::abs(ord) == INFINITY) {
+ std::swap(dim_[0], dim_[1]);
+ }
+ // If the dim of the second reduction is greater than that of the first reduction
+ // and we are not keeping the dims, then the fact that the output of the first
+ // reduction will have one fewer dimension means that the second reduction dim
+ // will be off by one, so we need to correct that.
+ if ((dim_[1] > dim_[0]) && !keepdim) {
+ dim_[1]--;
+ }
+ if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) {
+ result = self_.abs().sum(dim_[0], keepdim);
+ result = _norm_min_max(result, ord, dim_[1], keepdim);
+ } else {
+ TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm");
+ }
}
return result;
}
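As a worked example of the two-step reduction described in the comments above (sum of absolute values along one dim, then max), here is a rough Python equivalent; this is a sketch of the math, not a contract of the implementation:

import torch

A = torch.randn(3, 4)
# ord=1: maximum absolute column sum; ord=inf: maximum absolute row sum.
print(torch.allclose(torch.linalg.norm(A, ord=1, dim=(0, 1)),
                     A.abs().sum(dim=0).max()))
print(torch.allclose(torch.linalg.norm(A, ord=float('inf'), dim=(0, 1)),
                     A.abs().sum(dim=1).max()))
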
+// Performs vector norm
+// This function mostly serves as a wrapper for at::norm, but it overrides a few cases
+// for numpy compatibility. These cases are corrected within this wrapper, rather than
+// in at::norm itself, to avoid breaking backward compatibility.
+static Tensor _linalg_norm_vector(const Tensor& self, optional<Scalar> opt_ord, std::vector<int64_t> dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ if (opt_ord.has_value()) {
+ TORCH_INTERNAL_ASSERT(dim.size() == 1);
+ auto ord = opt_ord.value().toDouble();
+ Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self;
+ if (std::abs(ord) == INFINITY) {
+ // The ord = +/-infinity case is overridden because at::norm does not match numpy
+ // when the input contains extreme values (like nan or +/-inf) or if the input
+ // size is degenerate (like size(0), size(0, N), etc)
+ self_ = self_.abs();
+ return _norm_min_max(self_, ord, dim[0], keepdim);
+ } else if ((self_.numel() == 0) && (ord < 0)) {
+ // For negative orders with degenerate input sizes, at::norm's result does not
+ // match numpy.
+ Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim);
+ if (ord >= -1) {
+ // Result must be infinite in this case, and the simplest way to make that
+ // happen is to simply add infinity
+ result += INFINITY;
+ } else {
+ result = result.pow(1.0 / (ord + 1));
+ }
+ return result;
+ }
+ } else {
+ // If ord == None, need to check for unique dims because at::norm does not check it
+ // for this case.
+    std::vector<int64_t> dim_(dim);
+ maybe_wrap_dims(dim_, self.dim());
+ bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end();
+ TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")");
+ }
+ if (opt_dtype.has_value()) {
+ return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value());
+ } else {
+ return at::norm(self, opt_ord, dim, keepdim);
+ }
+}
+
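A small sketch of the ord=+/-inf vector path this wrapper overrides; the expected values follow the NumPy definition (abs().max() / abs().min()):

import torch

x = torch.tensor([1.0, -3.0, 2.0])
print(torch.linalg.norm(x, ord=float('inf')))   # expected: 3.0 (abs().max())
print(torch.linalg.norm(x, ord=float('-inf')))  # expected: 1.0 (abs().min())
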
+static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional<Scalar> opt_num_ord, optional<std::string> opt_str_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ // Callers must give the ord argument as either a number, a string, or neither.
+ // Since the user-facing API has no direct control over how this function is called, this is an internal assert.
+ TORCH_INTERNAL_ASSERT(!(opt_num_ord.has_value() && opt_str_ord.has_value()));
+ if (opt_dtype.has_value()) {
+ auto dtype = opt_dtype.value();
+    TORCH_CHECK(dtype == result.scalar_type(), "provided dtype must match dtype of result, but got ",
+ "dtype = ", dtype, ", out.dtype = ", result.scalar_type());
+ }
+ int64_t ndim = self.dim();
+ Tensor result_;
+ if (opt_str_ord.has_value()) {
+ // 'ord' is string
+ auto str_ord = opt_str_ord.value();
+ check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype);
+ if (str_ord == "fro") {
+ result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim);
+ } else if (str_ord == "nuc") {
+ if (opt_dim.has_value()) {
+ result_ = at::nuclear_norm(self, opt_dim.value(), keepdim);
+ } else {
+ result_ = at::nuclear_norm(self, keepdim);
+ }
+ }
+ } else {
+ // 'ord' is int or None
+    std::vector<int64_t> dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim);
+ if (!opt_num_ord.has_value() || dim_.size() == 1) {
+ result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype);
+ } else if (dim_.size() == 2) {
+ result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype);
+ } else {
+ TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is "
+ "not 1-D or 2-D");
+ }
+ }
+ resize_output(result, result_.sizes());
+ result.copy_(result_);
+ return result;
+}
+
+// Numerical or None norms
+Tensor linalg_norm(const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
+ Tensor result = at::empty({0}, options);
+ return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
+}
+
+// Frobenius and nuclear norms
+Tensor linalg_norm(const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device());
+ Tensor result = at::empty({0}, options);
+ return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
+}
+
+// Numerical or None norms
+Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ return linalg_norm_out_impl(result, self, opt_ord, c10::nullopt, opt_dim, keepdim, opt_dtype);
+}
+
+// Frobenius and nuclear norms
+Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype);
+}
+
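Putting the overloads together, a brief usage sketch of the new torch.linalg.norm entry points (numeric or None ord dispatches to the vector/matrix paths, string ord to frobenius_norm / nuclear_norm):

import torch

A = torch.randn(3, 4)
print(torch.linalg.norm(A))                 # ord=None: 2-norm of the flattened input
print(torch.linalg.norm(A, ord=1, dim=1))   # vector 1-norm along dim 1
print(torch.linalg.norm(A, ord='fro'))      # Frobenius norm
print(torch.linalg.norm(A, ord='nuc'))      # nuclear norm (sum of singular values)
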
static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
if (i == j)
return matrices[i];
diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index 5587e013949ec5d..414c8a6f6390413 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -11,8 +11,7 @@ DEFINE_DISPATCH(pow_tensor_tensor_stub);
DEFINE_DISPATCH(pow_tensor_scalar_stub);
Tensor& pow_out(Tensor& result, const Tensor& base, const Tensor& exp) {
- auto iter = TensorIterator::binary_op(result, base, exp,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(result, base, exp);
pow_tensor_tensor_stub(iter.device_type(), iter);
return result;
}
@@ -37,8 +36,7 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
} else if (!exp.isComplex() && (exp.toDouble() == 1.0)) {
result.resize_as_(base).copy_(base);
} else {
- auto iter = TensorIterator::unary_op(result, base.to(common_dtype),
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, base.to(common_dtype));
pow_tensor_scalar_stub(iter.device_type(), iter, exp);
}
return result;
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index 80d97a7e8a0c58b..5e66819fb98e203 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -1858,80 +1858,58 @@ static auto cell_params_base_registry =
return cell_params_deserializers[type](std::move(state));
});
-static auto registry =
- torch::RegisterOperators()
- .op("aten::quantized_lstm.input(Tensor input, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)",
- torch::RegisterOperators::options()
-             .kernel<decltype(quantized_lstm_input), quantized_lstm_input>(
-                 DispatchKey::CPUTensorId))
- .op("aten::quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)",
- torch::RegisterOperators::options()
-             .kernel<decltype(quantized_lstm_data), quantized_lstm_data>(
-                 DispatchKey::CPUTensorId))
- .op("aten::quantized_lstm.input_legacy(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_lstm_input_legacy),
- quantized_lstm_input_legacy>(DispatchKey::CPUTensorId))
- .op("aten::quantized_lstm.data_legacy(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_lstm_data_legacy),
- quantized_lstm_data_legacy>(DispatchKey::CPUTensorId))
- .op("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase",
- torch::RegisterOperators::options()
- .kernel<
- decltype(make_quantized_cell_params_dynamic),
- make_quantized_cell_params_dynamic>(
- DispatchKey::CPUTensorId))
- .op("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase",
- torch::RegisterOperators::options()
- .catchAllKernel<
- decltype(make_quantized_cell_params_fp16),
- &make_quantized_cell_params_fp16>())
- .op("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase",
- torch::RegisterOperators::options()
- .kernel<
- decltype(make_quantized_cell_params),
- make_quantized_cell_params>(DispatchKey::CPUTensorId))
- .op("aten::quantized_gru.input(Tensor input, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)",
- torch::RegisterOperators::options()
-             .kernel<decltype(quantized_gru_input), quantized_gru_input>(
-                 DispatchKey::CPUTensorId))
- .op("aten::quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)",
- torch::RegisterOperators::options()
-             .kernel<decltype(quantized_gru_data), quantized_gru_data>(
-                 DispatchKey::CPUTensorId))
- .op("aten::quantized_gru.input_legacy(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_gru_input_legacy),
- quantized_gru_input_legacy>(DispatchKey::CPUTensorId))
- .op("aten::quantized_gru.data_legacy(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_gru_data_legacy),
- quantized_gru_data_legacy>(DispatchKey::CPUTensorId))
- .op("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_lstm_cell_dynamic),
- quantized_lstm_cell_dynamic>(DispatchKey::CPUTensorId))
- .op("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_gru_cell_dynamic),
- quantized_gru_cell_dynamic>(DispatchKey::CPUTensorId))
- .op("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_rnn_relu_cell_dynamic),
- quantized_rnn_relu_cell_dynamic>(DispatchKey::CPUTensorId))
- .op("quantized::quantized_rnn_tanh_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor",
- torch::RegisterOperators::options()
- .kernel<
- decltype(quantized_rnn_tanh_cell_dynamic),
- quantized_rnn_tanh_cell_dynamic>(DispatchKey::CPUTensorId));
+TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(aten, m) {
+ m.def(
+ "quantized_lstm.input(Tensor input, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)");
+ m.def(
+ "quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)");
+ m.def(
+ "quantized_lstm.input_legacy(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)");
+ m.def(
+ "quantized_lstm.data_legacy(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)");
+ m.def(
+ "quantized_gru.input(Tensor input, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)");
+ m.def(
+ "quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)");
+ m.def(
+ "quantized_gru.input_legacy(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)");
+ m.def(
+ "quantized_gru.data_legacy(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)");
+}
+
+TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(quantized, m) {
+ m.def("make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase");
+ m.def("make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase");
+ m.def("make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase");
+ m.def("quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)");
+ m.def("quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor");
+ m.def("quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor");
+ m.def("quantized_rnn_tanh_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL(aten, CPU, m) {
+ m.impl("quantized_lstm.input", TORCH_FN(quantized_lstm_input));
+ m.impl("quantized_lstm.data", TORCH_FN(quantized_lstm_data));
+ m.impl("quantized_lstm.input_legacy", TORCH_FN(quantized_lstm_input_legacy));
+ m.impl("quantized_lstm.data_legacy", TORCH_FN(quantized_lstm_data_legacy));
+ m.impl("quantized_gru.input", TORCH_FN(quantized_gru_input));
+ m.impl("quantized_gru.data", TORCH_FN(quantized_gru_data));
+ m.impl("quantized_gru.input_legacy", TORCH_FN(quantized_gru_input_legacy));
+ m.impl("quantized_gru.data_legacy", TORCH_FN(quantized_gru_data_legacy));
+}
+
+TORCH_LIBRARY_IMPL(quantized, CPU, m) {
+ m.impl("make_quantized_cell_params_dynamic", TORCH_FN(make_quantized_cell_params_dynamic));
+ m.impl("make_quantized_cell_params", TORCH_FN(make_quantized_cell_params));
+ m.impl("quantized_lstm_cell_dynamic", TORCH_FN(quantized_lstm_cell_dynamic));
+ m.impl("quantized_gru_cell_dynamic", TORCH_FN(quantized_gru_cell_dynamic));
+ m.impl("quantized_rnn_relu_cell_dynamic", TORCH_FN(quantized_rnn_relu_cell_dynamic));
+ m.impl("quantized_rnn_tanh_cell_dynamic", TORCH_FN(quantized_rnn_tanh_cell_dynamic));
+}
+
+TORCH_LIBRARY_IMPL(quantized, CatchAll, m) {
+ m.impl("make_quantized_cell_params_fp16", TORCH_FN(make_quantized_cell_params_fp16));
+}
} // namespace
}} // namespace at::native
diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp
index 8deff003be12282..c2fdde8c58aeea1 100644
--- a/aten/src/ATen/native/RangeFactories.cpp
+++ b/aten/src/ATen/native/RangeFactories.cpp
@@ -173,7 +173,7 @@ Tensor& arange_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
}
Tensor r = result.is_contiguous() ? result : result.contiguous();
- auto iter = TensorIterator::nullary_op(r, /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::nullary_op(r);
arange_stub(iter.device_type(), iter, start, size, step);
if (!result.is_contiguous()) {
result.copy_(r);
diff --git a/aten/src/ATen/native/RowwisePrune.cpp b/aten/src/ATen/native/RowwisePrune.cpp
new file mode 100644
index 000000000000000..c373f3a64701697
--- /dev/null
+++ b/aten/src/ATen/native/RowwisePrune.cpp
@@ -0,0 +1,106 @@
+// Copyright 2004-present Facebook. All Rights Reserved.
+
+#include <ATen/ATen.h>
+
+
+namespace at {
+namespace native {
+
+namespace {
+
+template <typename index_t>
+std::tuple<Tensor, Tensor> _rowwise_prune_helper(
+ const Tensor& weights, const Tensor& mask,
+ ScalarType compressed_indices_dtype) {
+ int num_non_masked_rows = 0;
+ auto mask_contig = mask.contiguous();
+  auto mask_data = mask_contig.data_ptr<bool>();
+ for (int i = 0; i < mask.numel(); ++i) {
+ num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0);
+ }
+ int num_cols = weights.size(1);
+ auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols},
+ weights.options());
+ auto compressed_indices_mapping = at::empty({mask.numel()},
+ compressed_indices_dtype);
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ weights.scalar_type(),
+ "rowwise_prune_helper", [&]() {
+    auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr<scalar_t>();
+    auto compressed_indices_mapping_data =
+        compressed_indices_mapping.data_ptr<index_t>();
+    auto weights_data = weights.data_ptr<scalar_t>();
+ int last_row_kept = 0;
+ for (int i = 0; i < mask.numel(); i++) {
+ if (mask_data[i]) {
+ memcpy(pruned_2d_tensor_data + last_row_kept * num_cols,
+ weights_data + i * num_cols,
+ num_cols * sizeof (scalar_t));
+ compressed_indices_mapping_data[i] = last_row_kept;
+ last_row_kept++;
+ } else {
+ compressed_indices_mapping_data[i] = -1;
+ }
+ }
+ });
+  return std::tuple<Tensor, Tensor>(pruned_2d_tensor,
+ compressed_indices_mapping);
+}
+
+} // namespace
+
+
+// This operator introduces sparsity to the 'weights' matrix with the help
+// of the importance indicator 'mask'.
+//
+// A row is considered important and not pruned if the mask value for that
+// particular row is 1 (True), and not important otherwise.
+//
+// This operator doesn't zero out the pruned rows in-place. Instead, it
+// returns a tuple that contains a pruned weights tensor as well as a map that
+// can be used to look up the original row in the pruned weights tensor.
+// We refer to this map as the 'compressed indices map' going forward.
+
+// The 'compressed indices map' is a 1D tensor that contains one entry per
+// original row in 'weights'. The array index is the index for the original
+// non-pruned weight tensor and the value would be the re-mapped index in the
+// pruned weights tensor. If the value for an index is -1, it means the
+// corresponding row has been pruned from the original weight tensor.
+
+// Arguments:
+// 'weights' - two dimensional matrix that needs to be pruned.
+// 'mask' - 1D boolean tensor that represents whether a row is important or
+// not. A mask value of 1 means the row should be kept and 0 means the row
+// should be pruned.
+//
+// Returns:
+// A tuple containing two tensors,
+// 1. A pruned weight tensor that contains only the weights that are preserved
+// post pruning.
+// 2. A 1D tensor that contains the mapping between original weight row and
+// the corresponding row in the pruned weights tensor.
+std::tuple rowwise_prune(const Tensor& weights,
+ const Tensor& mask,
+ ScalarType compressed_indices_dtype) {
+ TORCH_CHECK(weights.ndimension() == 2,
+ "'weights' should have 2 dimensions.");
+ TORCH_CHECK(
+ mask.numel() == weights.size(0),
+ "Number of elements in 'mask' should be equivalent to the "
+ "number of rows in 'weights'."
+ )
+ TORCH_CHECK(
+ compressed_indices_dtype == ScalarType::Int ||
+ compressed_indices_dtype == ScalarType::Long,
+ "compressed_indices_dtype should be either int(int32) or long(int64).");
+
+ if (compressed_indices_dtype == at::ScalarType::Int) {
+    return _rowwise_prune_helper<int32_t>(weights, mask,
+ compressed_indices_dtype);
+ }
+  return _rowwise_prune_helper<int64_t>(weights, mask,
+ compressed_indices_dtype);
+}
+
+}} // namespace at::native
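To illustrate the documented semantics of the compressed indices map, here is a pure-Python sketch that builds the same outputs by hand (it does not call the new operator):

import torch

weights = torch.arange(12., dtype=torch.float32).reshape(4, 3)
mask = torch.tensor([True, False, True, True])

pruned = weights[mask]  # rows kept, in their original order
compressed = torch.full((mask.numel(),), -1, dtype=torch.int32)
compressed[mask] = torch.arange(int(mask.sum()), dtype=torch.int32)

print(pruned.shape)  # (3, 3)
print(compressed)    # tensor([ 0, -1,  1,  2], dtype=torch.int32)
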
diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h
index a1f324996698585..765c9937277fe82 100644
--- a/aten/src/ATen/native/SharedReduceOps.h
+++ b/aten/src/ATen/native/SharedReduceOps.h
@@ -344,15 +344,29 @@ namespace detail {
template <typename scalar_t>
struct LessOrNan {
- C10_DEVICE bool operator () (scalar_t a, scalar_t b) const {
- return at::_isnan(a) || a < b;
+ C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
+ // If (a == b), then choose the one with lower idx, else min(a, b)
+ if (at::_isnan(a)) {
+ if (at::_isnan(b)) {
+ return idx_a < idx_b;
+ }
+ return true;
+ }
+ return (a == b) ? idx_a < idx_b : (a < b);
}
};
template <typename scalar_t>
struct GreaterOrNan {
- C10_DEVICE bool operator () (scalar_t a, scalar_t b) const {
- return at::_isnan(a) || a > b;
+ C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
+ // If (a == b), then choose the one with lower idx, else max(a, b)
+ if (at::_isnan(a)) {
+ if (at::_isnan(b)) {
+ return idx_a < idx_b;
+ }
+ return true;
+ }
+ return (a == b) ? idx_a < idx_b : (a > b);
}
};
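With the index-aware comparators above, equal values (and pairs of NaNs) resolve to the lower index, so argmax/argmin should return the first occurrence, in line with NumPy. A quick sketch of the expected behaviour:

import torch

x = torch.tensor([1.0, 3.0, 3.0, 2.0])
print(torch.argmax(x))                              # expected: tensor(1)
print(torch.argmin(torch.tensor([5.0, 0.0, 0.0])))  # expected: tensor(1)
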
@@ -367,11 +381,11 @@ struct MinMaxReductionOps {
}
static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
- return comp_t{}(arg.first, val) ? arg : arg_t(val, idx);
+ return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
}
static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
- return comp_t{}(a.first, b.first) ? a : b;
+ return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
}
static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 09c2e2261d2af36..3d333923c8b9c51 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -54,6 +54,7 @@
#include
#include
#include
+#include <ATen/MemoryOverlap.h>
#include
#include
#include
@@ -353,6 +354,10 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)),
"index_add_(): Number of indices should be equal to self.size(dim)");
+ at::assert_no_internal_overlap(self);
+ at::assert_no_partial_overlap(self, index);
+ at::assert_no_partial_overlap(self, source);
+
auto index_contig = index.contiguous();
  auto index_data = index_contig.data_ptr<int64_t>();
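The new asserts make index_add_ reject inputs whose memory overlaps the destination. A rough sketch of the kind of call they are meant to catch (the exact error message is an assumption):

import torch

dest = torch.zeros(3, 3)
idx = torch.tensor([0, 1])
try:
    dest.index_add_(0, idx, dest[1:])  # source partially overlaps self
except RuntimeError as e:
    print("rejected:", e)
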
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index 8ad277264681a70..2eca94b36b37d48 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -816,9 +816,9 @@ void TensorIterator::select_all_keeping_dim(int start_dim, IntArrayRef indices)
}
TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
- const Tensor& b, bool check_mem_overlap) {
+ const Tensor& b) {
return TensorIteratorConfig()
- .set_check_mem_overlap(check_mem_overlap)
+ .set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.add_input(b)
@@ -830,9 +830,9 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
}
TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
- const Tensor& b, bool check_mem_overlap) {
+ const Tensor& b) {
return TensorIteratorConfig()
- .set_check_mem_overlap(check_mem_overlap)
+ .set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.add_input(b)
@@ -841,10 +841,9 @@ TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
.build();
}
-TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a,
- bool check_mem_overlap) {
+TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {
return TensorIteratorConfig()
- .set_check_mem_overlap(check_mem_overlap)
+ .set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.cast_common_dtype_to_outputs(false)
@@ -853,10 +852,10 @@ TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a,
.build();
}
-TensorIterator TensorIterator::nullary_op(Tensor& out, bool check_mem_overlap) {
+TensorIterator TensorIterator::nullary_op(Tensor& out) {
return TensorIteratorConfig()
+ .set_check_mem_overlap(true)
.check_all_same_dtype(false)
- .set_check_mem_overlap(check_mem_overlap)
.add_output(out)
// FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342
.resize_outputs(false)
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index 9f58347c6a0bdcf..27ef0b8eda837b3 100644
--- a/aten/src/ATen/native/TensorIterator.h
+++ b/aten/src/ATen/native/TensorIterator.h
@@ -156,14 +156,10 @@ struct CAFFE2_API TensorIterator {
void foreach_reduced_elt(loop_subiter_t loop, bool parallelize=true);
- static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b,
- bool check_mem_overlap = false);
- static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b,
- bool check_mem_overlap = false);
- static TensorIterator unary_op(Tensor& out, const Tensor& a,
- bool check_mem_overlap = false);
- static TensorIterator nullary_op(Tensor& out,
- bool check_mem_overlap = false);
+ static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b);
+ static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b);
+ static TensorIterator unary_op(Tensor& out, const Tensor& a);
+ static TensorIterator nullary_op(Tensor& out);
static TensorIterator reduce_op(Tensor& out, const Tensor& a);
static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a);
diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp
index 337ea73393ec68f..48dab43b2dc8132 100644
--- a/aten/src/ATen/native/TensorProperties.cpp
+++ b/aten/src/ATen/native/TensorProperties.cpp
@@ -53,20 +53,14 @@ bool cudnn_is_acceptable(const Tensor& self) {
}
Tensor detach(const Tensor& self) {
-#ifndef USE_STATIC_DISPATCH
// this just exists to give us a hook in VariableType and an entry in Declarations.yaml
//AT_ERROR("detach is not implemented for Tensor");
-#endif
- // this is no-op for USE_STATIC_DISPATCH mode
return self;
}
Tensor & detach_(Tensor & self) {
-#ifndef USE_STATIC_DISPATCH
// this just exists to give us a hook in VariableType and an entry in Declarations.yaml
//AT_ERROR("detach_ is not implemented for Tensor");
-#endif
- // this is no-op for USE_STATIC_DISPATCH mode
return self;
}
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 50203aef98a5d20..73ad9840b0bead1 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -39,8 +39,7 @@ namespace native {
// operators (more is foreseeable) and is more flexible and elegant than the latter.
template <typename Stub>
static inline Tensor& unary_op_impl_out(Tensor& result, const Tensor& self, Stub& stub) {
- auto iter = TensorIterator::unary_op(result, self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, self);
stub(iter.device_type(), iter);
return result;
}
@@ -61,8 +60,7 @@ static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, co
// Runs the function complex->complex, as TensorIterator expects
Tensor complex_result = at::empty({0}, self.options());
- auto iter = TensorIterator::unary_op(complex_result, self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(complex_result, self);
stub(iter.device_type(), iter);
// Copies the complex result to the actual result and returns it
@@ -108,9 +106,9 @@ Tensor acos(const Tensor& self) { return unary_op_impl(self, at::acos_out); }
Tensor& acos_(Tensor& self) { return unary_op_impl_(self, at::acos_out); }
// arccos, alias for acos
-Tensor& arccos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acos_stub); }
-Tensor arccos(const Tensor& self) { return unary_op_impl(self, at::acos_out); }
-Tensor& arccos_(Tensor& self) { return unary_op_impl_(self, at::acos_out); }
+Tensor& arccos_out(Tensor& result, const Tensor& self) { return at::acos_out(result, self); }
+Tensor arccos(const Tensor& self) { return self.acos(); }
+Tensor& arccos_(Tensor& self) { return self.acos_(); }
static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -140,18 +138,18 @@ Tensor asin(const Tensor& self) { return unary_op_impl(self, at::asin_out); }
Tensor& asin_(Tensor& self) { return unary_op_impl_(self, at::asin_out); }
// arcsin, alias of asin
-Tensor& arcsin_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asin_stub); }
-Tensor arcsin(const Tensor& self) { return unary_op_impl(self, at::asin_out); }
-Tensor& arcsin_(Tensor& self) { return unary_op_impl_(self, at::asin_out); }
+Tensor& arcsin_out(Tensor& result, const Tensor& self) { return at::asin_out(result, self); }
+Tensor arcsin(const Tensor& self) { return self.asin(); }
+Tensor& arcsin_(Tensor& self) { return self.asin_(); }
Tensor& atan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atan_stub); }
Tensor atan(const Tensor& self) { return unary_op_impl(self, at::atan_out); }
Tensor& atan_(Tensor& self) { return unary_op_impl_(self, at::atan_out); }
// arctan, alias of atan
-Tensor& arctan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atan_stub); }
-Tensor arctan(const Tensor& self) { return unary_op_impl(self, at::atan_out); }
-Tensor& arctan_(Tensor& self) { return unary_op_impl_(self, at::atan_out); }
+Tensor& arctan_out(Tensor& result, const Tensor& self) { return at::atan_out(result, self); }
+Tensor arctan(const Tensor& self) { return self.atan(); }
+Tensor& arctan_(Tensor& self) { return self.atan_(); }
// Note [Complex abs and angle]
// Complex inputs to abs and angle return float results by default.
@@ -314,10 +312,20 @@ Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
Tensor asinh(const Tensor& self) { return unary_op_impl(self, at::asinh_out); }
Tensor& asinh_(Tensor& self) { return unary_op_impl_(self, at::asinh_out); }
+// arcsinh, alias for asinh
+Tensor& arcsinh_out(Tensor& result, const Tensor& self) { return at::asinh_out(result, self); }
+Tensor arcsinh(const Tensor& self) { return self.asinh(); }
+Tensor& arcsinh_(Tensor& self) { return self.asinh_(); }
+
Tensor& atanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atanh_stub); }
Tensor atanh(const Tensor& self) { return unary_op_impl(self, at::atanh_out); }
Tensor& atanh_(Tensor& self) { return unary_op_impl_(self, at::atanh_out); }
+// arctanh, alias for atanh
+Tensor& arctanh_out(Tensor& result, const Tensor& self) { return at::atanh_out(result, self); }
+Tensor arctanh(const Tensor& self) { return self.atanh(); }
+Tensor& arctanh_(Tensor& self) { return self.atanh_(); }
+
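Assuming the new aliases are exposed at the Python level as torch.arcsinh / torch.arctanh (the yaml registration is not shown in this hunk), they simply forward to the existing ops; a quick sanity sketch:

import torch

x = torch.randn(5).clamp(-0.9, 0.9)  # keep atanh's domain valid
print(torch.allclose(torch.arcsinh(x), torch.asinh(x)))  # expected: True
print(torch.allclose(torch.arctanh(x), torch.atanh(x)))  # expected: True
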
Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); }
Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); }
Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); }
@@ -333,10 +341,7 @@ Tensor& logit_out(
Tensor& result,
const Tensor& self,
    c10::optional<double> eps) {
- auto iter = TensorIterator::unary_op(
- result,
- self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, self);
logit_stub(iter.device_type(), iter, Scalar(eps ? eps.value() : -1.0));
return result;
}
@@ -435,8 +440,7 @@ Tensor& clamp_out(Tensor& result, const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
if (min && max) {
TORCH_CHECK(self.layout() == Layout::Strided,
"clamp only supports strided layout, got: ", self.layout());
- auto iter = TensorIterator::unary_op(result, self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, self);
clamp_stub(iter.device_type(), iter, *min, *max);
} else if (max) {
at::clamp_max_out(result, self, *max);
@@ -461,8 +465,7 @@ Tensor& clamp_max_out(Tensor& result, const Tensor& self, Scalar max) {
TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors.");
TORCH_CHECK(self.layout() == Layout::Strided,
"clamp_max only supports strided layout, got: ", self.layout());
- auto iter = TensorIterator::unary_op(result, self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, self);
clamp_max_stub(iter.device_type(), iter, max);
return result;
}
@@ -480,8 +483,7 @@ Tensor& clamp_min_out(Tensor& result, const Tensor& self, Scalar min) {
TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors.");
TORCH_CHECK(self.layout() == Layout::Strided,
"clamp_min only supports strided layout, got: ", self.layout());
- auto iter = TensorIterator::unary_op(result, self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, self);
clamp_min_stub(iter.device_type(), iter, min);
return result;
}
@@ -518,8 +520,7 @@ Tensor& polygamma_(Tensor& self, int64_t n) {
}
Tensor& polygamma_out(Tensor& result, int64_t n, const Tensor& self) {
TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
- auto iter = TensorIterator::unary_op(result, self,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::unary_op(result, self);
polygamma_stub(iter.device_type(), iter, n);
return result;
}
@@ -563,8 +564,7 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
Tensor& _##op##_out_##prefix(Tensor& result, const Tensor& self) { \
checkDeviceType(#op, result, DeviceType::device); \
checkLayout(#op, result, Layout::Strided); \
- auto iter = TensorIterator::unary_op(result, self, \
- /*check_mem_overlap=*/true); \
+ auto iter = TensorIterator::unary_op(result, self); \
op##_stub(iter.device_type(), iter); \
return result; \
}
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 946c0a7276150a9..133d51bd99a89ec 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -318,14 +318,14 @@ void rshift_kernel(TensorIterator& iter) {
void lt_kernel(TensorIterator& iter) {
if (iter.dtype() == ScalarType::Bool) {
- AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16, iter.input_dtype(), "lt_cpu", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "lt_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a < b;
});
});
} else {
- AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "lt_cpu", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "lt_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
diff --git a/aten/src/ATen/native/cpu/FillKernel.cpp b/aten/src/ATen/native/cpu/FillKernel.cpp
index be9f1a445d61dce..69a304257fca777 100644
--- a/aten/src/ATen/native/cpu/FillKernel.cpp
+++ b/aten/src/ATen/native/cpu/FillKernel.cpp
@@ -12,7 +12,7 @@ namespace {
template <typename scalar_t>
-static void fill_non_native_type(TensorIterator& iter, Scalar value_scalar) {
+void fill_non_native_type(TensorIterator& iter, Scalar value_scalar) {
   auto value = value_scalar.to<scalar_t>().x;
   using H = typename std::make_signed<decltype(value)>::type; // Signed type has more acceleration
// Reserve the representation of value. static_cast(value) is implementation defined.
@@ -23,11 +23,24 @@ static void fill_non_native_type(TensorIterator& iter, Scalar value_scalar) {
     [val]() { return Vec256<H>(val); });
}
+template <>
+void fill_non_native_type<c10::complex<at::Half>>(TensorIterator& iter, Scalar value_scalar) {
+  static_assert(sizeof(c10::complex<at::Half>) == sizeof(int32_t), "Size of ComplexHalf should be 32-bits");
+  auto value = c10::complex<at::Half>(value_scalar.to<c10::complex<float>>());
+  auto val = *reinterpret_cast<int32_t*>(std::addressof(value));
+  cpu_kernel_vec</*check_dynamic_cast=*/false>(
+      iter,
+      [val]() -> int32_t { return val; },
+      [val]() { return Vec256<int32_t>(val); });
+}
+
void fill_kernel(TensorIterator& iter, Scalar value_scalar) {
if (iter.dtype() == ScalarType::Half) {
     fill_non_native_type<at::Half>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::BFloat16) {
     fill_non_native_type<at::BFloat16>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::ComplexHalf) {
+    fill_non_native_type<c10::complex<at::Half>>(iter, value_scalar);
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, iter.dtype(), "fill_cpu", [&]() {
      scalar_t value = value_scalar.to<scalar_t>();
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index 840b6ce5d75d707..9ad80ec5daeb614 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -98,7 +98,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
}
void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
iter.dtype(), "index_cpu", [&] {
cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)dst = *(scalar_t*)(src + offset);
@@ -108,11 +108,11 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
// NOTE: duplicate indices are only supported if accumulate is true.
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
iter.dtype(), "index_put", [&] {
if (accumulate) {
bool use_parallel_for = ((iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
- if (iter.dtype() == at::ScalarType::Float && use_parallel_for) {
+ if (iter.dtype() == ScalarType::Float && use_parallel_for) {
cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
cpu_atomic_add_float((float*)(dst + offset), *(float*)src);
});
@@ -151,11 +151,11 @@ void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) {
}
void masked_fill_kernel(TensorIterator& iter, Scalar value) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
iter.dtype(), "masked_fill", [&] {
    scalar_t scalar_val = value.to<scalar_t>();
auto mask_dtype = iter.input_dtype(0);
- if (mask_dtype == at::ScalarType::Bool) {
+ if (mask_dtype == ScalarType::Bool) {
      cpu_masked_fill_kernel<scalar_t, bool>(iter, scalar_val);
    } else {
      cpu_masked_fill_kernel<scalar_t, unsigned char>(iter, scalar_val);
@@ -187,10 +187,10 @@ void cpu_masked_select_serial_kernel(TensorIterator& iter, const func_t& f) {
}
void masked_select_serial_kernel(TensorIterator& iter, int64_t result_stride) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
iter.dtype(), "masked_select", [&] {
auto mask_dtype = iter.input_dtype(1);
- if (mask_dtype == at::ScalarType::Bool) {
+ if (mask_dtype == ScalarType::Bool) {
cpu_masked_select_serial_kernel(iter, [result_stride](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
});
@@ -226,10 +226,10 @@ void cpu_masked_select_kernel(TensorIterator& iter, const func_t& f) {
}
void masked_select_kernel(TensorIterator& iter, int64_t result_stride) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
iter.dtype(), "masked_select", [&] {
auto mask_dtype = iter.input_dtype(1);
- if (mask_dtype == at::ScalarType::Bool) {
+ if (mask_dtype == ScalarType::Bool) {
cpu_masked_select_kernel(iter, [result_stride](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
});
diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp
index dceefdd6efc061a..a22e349a9cf641b 100644
--- a/aten/src/ATen/native/cpu/LerpKernel.cpp
+++ b/aten/src/ATen/native/cpu/LerpKernel.cpp
@@ -15,8 +15,7 @@ static void lerp_kernel_scalar(
const Tensor& end,
Scalar weight) {
TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype());
- auto iter = TensorIterator::binary_op(ret, self, end,
- /*check_mem_overlap=*/true);
+ auto iter = TensorIterator::binary_op(ret, self, end);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(ret.scalar_type(), "lerp_kernel_scalar", [&] {
    using value_t = typename c10::scalar_value_type<scalar_t>::type;
    scalar_t weight_val = weight.to<scalar_t>();
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 40938b97e914f68..cce011a5a3127b2 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -292,7 +292,7 @@ static void argmax_kernel_impl(TensorIterator &iter) {
binary_kernel_reduce(
iter,
      ArgMaxOps<scalar_t>{},
-      std::pair<scalar_t, int64_t>(lower_bound<scalar_t>(), -1));
+      std::pair<scalar_t, int64_t>(lower_bound<scalar_t>(), 0));
});
}
@@ -301,7 +301,7 @@ static void argmin_kernel_impl(TensorIterator &iter) {
binary_kernel_reduce(
iter,
      ArgMinOps<scalar_t>{},
-      std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
+      std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), 0));
});
}
diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index 359fd5b7bf3e214..b8af7155cf4838f 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -29,10 +29,7 @@ void prelu_cuda_kernel_share_weights(
Tensor& result,
const scalar_t* weight_data)
{
- at::TensorIterator iter = TensorIteratorConfig()
- .add_output(result)
- .add_input(input)
- .build();
+ auto iter = TensorIterator::unary_op(result, input);
at::native::gpu_kernel(iter,
[weight_data] GPU_LAMBDA (scalar_t input_val) {
@@ -540,7 +537,16 @@ static Tensor threshold_out_cuda(
Scalar value,
const Tensor& other) {
Tensor result = opt_result.value_or(Tensor());
- auto iter = TensorIterator::binary_op(result, self, other);
+ auto iter = TensorIteratorConfig()
+ .set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay
+ .add_output(result)
+ .add_input(self)
+ .add_input(other)
+ .allow_cpu_scalars(true)
+ .promote_inputs_to_common_dtype(true)
+ .cast_common_dtype_to_outputs(true)
+ .enforce_safe_casting_to_output(true)
+ .build();
threshold_kernel(iter, threshold, value);
return iter.output();
}
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
new file mode 100644
index 000000000000000..05c7c47784038d6
--- /dev/null
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -0,0 +1,126 @@
+#include
+#include
+
+namespace at { namespace native {
+
+namespace {
+
+template<typename T>
+struct AddScalarFunctor_ {
+ __device__ void operator() (
+ int chunk_size,
+ TensorListMetadata<1>& tl,
+ T scalar) {
+ int tensor_loc = tl.block_to_tensor[blockIdx.x];
+ int chunk_idx = tl.block_to_chunk[blockIdx.x];
+ int n = tl.sizes[tensor_loc];
+
+ T* x = (T*)tl.addresses[0][tensor_loc];
+ x += chunk_idx * chunk_size;
+
+ n -= chunk_idx * chunk_size;
+
+ T r_x[kILP];
+
+ // to make things simple, we put aligned case in a different code path
+ if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x)) {
+ for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+ // load
+ load_store(r_x, x, 0 , i_start);
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ r_x[ii] = static_cast(r_x[ii]) + scalar;
+ }
+ // store
+ load_store(x, r_x, i_start, 0);
+ }
+ }
+ else {
+ // Non-divergent exit condition for __syncthreads, not necessary here
+ for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ r_x[ii] = 0;
+ int i = i_start + threadIdx.x + ii * blockDim.x;
+ if(i < n && i < chunk_size) {
+ r_x[ii] = x[i];
+ }
+ }
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ r_x[ii] = static_cast(r_x[ii]) + scalar;
+ }
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ int i = i_start + threadIdx.x + ii * blockDim.x;
+ if(i < n && i < chunk_size)
+ x[i] = r_x[ii];
+ }
+ }
+ }
+ }
+};
+
+template<typename T>
+struct AddScalarFunctor {
+ __device__ void operator() (
+ int chunk_size,
+ TensorListMetadata<2>& tl,
+ T scalar) {
+ int tensor_loc = tl.block_to_tensor[blockIdx.x];
+ int chunk_idx = tl.block_to_chunk[blockIdx.x];
+ int n = tl.sizes[tensor_loc];
+
+ T* x = (T*)tl.addresses[0][tensor_loc];
+ x += chunk_idx * chunk_size;
+
+ T* out = (T*)tl.addresses[1][tensor_loc];
+ out += chunk_idx * chunk_size;
+
+ n -= chunk_idx * chunk_size;
+
+ T r_x[kILP];
+ T r_out[kILP];
+
+ // to make things simple, we put aligned case in a different code path
+ if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) {
+ for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+ // load
+ load_store(r_x, x, 0 , i_start);
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ r_out[ii] = static_cast(r_x[ii]) + scalar;
+ }
+ // store
+ load_store(out, r_out, i_start, 0);
+ }
+ }
+ else {
+ // Non-divergent exit condition for __syncthreads, not necessary here
+ for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ r_x[ii] = 0;
+ int i = i_start + threadIdx.x + ii * blockDim.x;
+ if(i < n && i < chunk_size) {
+ r_x[ii] = x[i];
+ }
+ }
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ r_out[ii] = static_cast(r_x[ii]) + scalar;
+ }
+#pragma unroll
+ for(int ii = 0; ii < kILP; ii++) {
+ int i = i_start + threadIdx.x + ii * blockDim.x;
+ if(i < n && i < chunk_size)
+ out[i] = r_out[ii];
+ }
+ }
+ }
+ }
+};
+
+} // namespace
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu b/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu
index 6924aef1ccbe22c..772012c608fba7a 100644
--- a/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu
+++ b/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu
@@ -1,81 +1,14 @@
#include
-#include
-#include
-
-// NOTE: CUDA on Windows requires that the enclosing function
-// of a __device__ lambda not have internal linkage.
+#include
+#include
namespace at { namespace native {
-namespace {
-
-template<typename x_t, typename out_t>
-struct AddScalarFunctor {
- __device__ void operator() (
- int chunk_size,
- TensorListMetadata<2>& tl,
- x_t scalar) {
- int tensor_loc = tl.block_to_tensor[blockIdx.x];
- int chunk_idx = tl.block_to_chunk[blockIdx.x];
- int n = tl.sizes[tensor_loc];
-
- x_t* x = (x_t*)tl.addresses[0][tensor_loc];
- x += chunk_idx * chunk_size;
-
- out_t* out = (out_t*)tl.addresses[1][tensor_loc];
- out += chunk_idx * chunk_size;
-
- n -= chunk_idx * chunk_size;
-
- x_t r_x[kILP];
- out_t r_out[kILP];
-
- // to make things simple, we put aligned case in a different code path
- if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) {
- for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
- // load
- load_store(r_x, x, 0 , i_start);
-#pragma unroll
- for(int ii = 0; ii < kILP; ii++) {
- r_out[ii] = static_cast(r_x[ii]) + scalar;
- }
- // store
- load_store(out, r_out, i_start, 0);
- }
- }
- else {
- // Non-divergent exit condition for __syncthreads, not necessary here
- for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
-#pragma unroll
- for(int ii = 0; ii < kILP; ii++) {
- r_x[ii] = 0;
- int i = i_start + threadIdx.x + ii * blockDim.x;
- if(i < n && i < chunk_size) {
- r_x[ii] = x[i];
- }
- }
-#pragma unroll
- for(int ii = 0; ii < kILP; ii++) {
- r_out[ii] = static_cast(r_x[ii]) + scalar;
- }
-#pragma unroll
- for(int ii = 0; ii < kILP; ii++) {
- int i = i_start + threadIdx.x + ii * blockDim.x;
- if(i < n && i < chunk_size)
- out[i] = r_out[ii];
- }
- }
- }
- }
-};
-
-} // namespace
-
 std::vector<Tensor> foreach_tensor_add_scalar_kernel_cuda(TensorList tensors, Scalar scalar) {
- TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
+ verify_list(tensors);
if (!check_fast_route(tensors, scalar)) {
- return at::native::foreach_add_scalar_kernel_fallback(tensors, scalar);
+ return at::native::foreach_tensor_add_scalar_kernel_slow(tensors, scalar);
}
  std::vector<std::vector<Tensor>> tensor_lists;
@@ -88,9 +21,24 @@ std::vector foreach_tensor_add_scalar_kernel_cuda(TensorList tensors, Sc
tensor_lists.emplace_back(std::move(vec_res));
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_tensor_add_scalar_kernel_cuda", [&]() {
-        multi_tensor_apply<2>(tensor_lists, AddScalarFunctor<scalar_t, scalar_t>(), scalar.to<scalar_t>());
+        multi_tensor_apply<2>(tensor_lists, AddScalarFunctor<scalar_t>(), scalar.to<scalar_t>());
});
return tensor_lists[1];
}
+void foreach_tensor_add_scalar_kernel_cuda_(TensorList tensors, Scalar scalar) {
+ verify_list(tensors);
+
+ if (!check_fast_route(tensors, scalar)) {
+ return at::native::foreach_tensor_add_scalar_kernel_slow_(tensors, scalar);
+ }
+
+    std::vector<std::vector<Tensor>> tensor_lists;
+ tensor_lists.emplace_back(std::move(tensors.vec()));
+
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_tensor_add_scalar_kernel_cuda_", [&]() {
+        multi_tensor_apply<1>(tensor_lists, AddScalarFunctor_<scalar_t>(), scalar.to<scalar_t>());
+ });
+}
+
}} // namespace at::native
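A usage sketch for the fast path added here, assuming the kernels are exposed through the private foreach API as torch._foreach_add and torch._foreach_add_ (requires a CUDA device for this code path):

import torch

tensors = [torch.ones(4, device='cuda') for _ in range(3)]
out = torch._foreach_add(tensors, 1.0)  # out-of-place: returns new tensors
torch._foreach_add_(tensors, 1.0)       # in-place: single-list multi_tensor_apply
print(out[0], tensors[0])               # both expected to be all 2s
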
diff --git a/aten/src/ATen/native/cuda/ForeachUtils.cuh b/aten/src/ATen/native/cuda/ForeachUtils.cuh
deleted file mode 100644
index e3800fb7f6fee06..000000000000000
--- a/aten/src/ATen/native/cuda/ForeachUtils.cuh
+++ /dev/null
@@ -1,56 +0,0 @@
-#pragma once
-#include
-#include
-#include
-namespace at {
-namespace native {
-namespace {
-
-static constexpr int64_t kILP = 4;
-static constexpr int64_t kChunkSize = 65536;
-static constexpr int64_t kBlockSize = 512;
-
-template<typename T>
-__device__ __forceinline__ bool is_aligned(T* p){
- return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
-}
-
-template<typename T>
-__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
-  using LT = at::native::memory::aligned_vector<T, kILP>;
- ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
-}
-
-}
-
-bool check_fast_route(TensorList tensors, Scalar scalar) {
- TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
- auto expected_dtype = tensors[0].dtype();
- auto expected_device = tensors[0].device();
-
- for (auto t : tensors) {
- if (t.dtype() != expected_dtype) {
- return false;
- }
-
- if (t.device() != expected_device) {
- return false;
- }
-
- if (t.layout() != at::kStrided) {
- return false;
- }
-
- if (!t.is_non_overlapping_and_dense()) {
- return false;
- }
-
- if ((at::isIntegralType(t.scalar_type(), true) && scalar.isFloatingPoint()) ||
- t.scalar_type() == at::kBool) {
- return false;
- }
- }
-
- return true;
-}
-}} // at::native
diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 688f035bc4a9edd..b444ca418906a11 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -4,6 +4,7 @@
#include
#include
#include
+#include <ATen/MemoryOverlap.h>
#include
#include
#include
@@ -449,6 +450,10 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
TORCH_CHECK(source.dim() <= MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING);
TORCH_CHECK(index.dim() <= MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING);
+ at::assert_no_internal_overlap(self);
+ at::assert_no_partial_overlap(self, index);
+ at::assert_no_partial_overlap(self, source);
+
// The `source` is partitioned into two parts:
// -the size of each slice we are indexing, which is the
// total size of the tensor ignoring dimension `dim`;
diff --git a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp
index cd3ce5cb1df80de..ddc03d5fb10669c 100644
--- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp
+++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp
@@ -3,6 +3,7 @@
#include
#include
#include
+#include <ATen/MemoryOverlap.h>
namespace at { namespace native {
@@ -60,6 +61,7 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor &
}
Tensor & fmod_cuda_out(Tensor & result, const Tensor & self, Scalar other) {
+ at::assert_no_internal_overlap(result);
return legacy::cuda::_th_fmod_out(result, self, other);
}
@@ -68,6 +70,7 @@ Tensor fmod_cuda(const Tensor & self, Scalar other) {
}
Tensor & fmod_cuda_out(Tensor & result, const Tensor & self, const Tensor & other) {
+ at::assert_no_internal_overlap(result);
Tensor b_self, b_other;
// optimization that codegen used to do; avoids broadcast.
if (other.dim() == 0) {
@@ -88,6 +91,7 @@ Tensor fmod_cuda(const Tensor & self, const Tensor & other) {
}
Tensor & fmod_cuda_(Tensor & self, Scalar other) {
+ at::assert_no_internal_overlap(self);
return legacy::cuda::_th_fmod_(self, other);
}
@@ -96,6 +100,7 @@ Tensor & fmod_cuda_(Tensor & self, const Tensor & other) {
if (other.dim() == 0) {
return fmod_cuda_(self, other.item());
}
+ at::assert_no_internal_overlap(self);
Tensor b_other;
std::tie(b_other) = expand_inplace(self, other, "fmod_");
return legacy::cuda::_th_fmod_(self, b_other);
diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu
index 633c7fe334f88e4..10301ed6a7c40de 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -111,6 +111,11 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self_.scalar_type();
if (mat1.numel() == 0) {
+ // By definition, when beta==0, values in self should be ignored. nans and infs
+ // should not propagate
+ if (beta.toComplexDouble() == 0.) {
+ return result.zero_();
+ }
return at::native::mul_out(result, self, at::native::scalar_tensor(beta, at::device(at::kCPU).dtype(self.scalar_type())));
}
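The beta == 0 early return above exists so that NaN/Inf values in self are never read when mat1 is empty. A hedged illustration of the semantics being enforced (example code, not part of the change set):

    import torch

    inp = torch.full((2, 3), float("nan"))
    mat1 = torch.empty(2, 0)
    mat2 = torch.empty(0, 3)
    out = torch.addmm(inp, mat1, mat2, beta=0.0)
    print(out)  # expected: zeros; the NaNs in inp must not propagate when beta == 0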
diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
index 487c480881f45c0..f82a0d9a58c8ff6 100644
--- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh
+++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh
@@ -1,12 +1,28 @@
#include
#include
-#include <ATen/native/cuda/ForeachUtils.cuh>
#include
+#include
+#include
namespace at { namespace native {
namespace {
+static constexpr int64_t kILP = 4;
+static constexpr int64_t kChunkSize = 65536;
+static constexpr int64_t kBlockSize = 512;
+
+template<typename T>
+__device__ __forceinline__ bool is_aligned(T* p){
+ return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
+}
+
+template<typename T>
+__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
+ using LT = at::native::memory::aligned_vector<T, kILP>;
+ ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
+}
+
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
diff --git a/aten/src/ATen/native/mkldnn/BinaryOps.cpp b/aten/src/ATen/native/mkldnn/BinaryOps.cpp
index e0bea2d2d1e9dab..3364fe8b335c683 100644
--- a/aten/src/ATen/native/mkldnn/BinaryOps.cpp
+++ b/aten/src/ATen/native/mkldnn/BinaryOps.cpp
@@ -12,27 +12,27 @@ Tensor& mkldnn_add_out(
const Tensor& self,
const Tensor& other,
Scalar alpha) {
- AT_ERROR("mkldnn_add_out: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_add_out: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_add(const Tensor& self, const Tensor& other, Scalar alpha) {
- AT_ERROR("mkldnn_add: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_add: ATen not compiled with MKLDNN support");
}
Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) {
- AT_ERROR("mkldnn_add_: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_add_: ATen not compiled with MKLDNN support");
}
Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
- AT_ERROR("mkldnn_mul_out: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_mul_out: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_mul(const Tensor& self, const Tensor& other) {
- AT_ERROR("mkldnn_mul: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_mul: ATen not compiled with MKLDNN support");
}
Tensor& mkldnn_mul_(Tensor& self, const Tensor& other) {
- AT_ERROR("mkldnn_mul_: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_mul_: ATen not compiled with MKLDNN support");
}
} // namespace native
@@ -76,7 +76,7 @@ Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) {
}
Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
- AT_ASSERTM(result.sizes() == self.sizes(),
+ TORCH_CHECK(result.sizes() == self.sizes(),
"mkldnn_mul_out: the output size should be same as input size");
ideep::tensor& z = itensor_from_mkldnn(result);
ideep::tensor& x = itensor_from_mkldnn(self);
@@ -89,7 +89,7 @@ Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other)
return result;
} else {
- AT_ASSERTM(self.sizes() == other.sizes(),
+ TORCH_CHECK(self.sizes() == other.sizes(),
"mkldnn_mul_out: currently mkldnn not support broadcasting");
ideep::tensor y = itensor_from_mkldnn(other);
ideep::binary::compute(x, y, z, dnnl::algorithm::binary_mul);
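mkldnn_mul_out only accepts equal-sized operands; the TORCH_CHECKs above turn the former internal asserts into regular user-facing errors. A small sketch of what that looks like from Python, assuming a build with MKL-DNN enabled (example code, not part of the change set):

    import torch

    a = torch.randn(2, 3).to_mkldnn()
    b = torch.randn(2, 3).to_mkldnn()
    torch.mul(a, b)                                   # ok: matching sizes
    try:
        torch.mul(a, torch.randn(1, 3).to_mkldnn())   # broadcasting is rejected
    except RuntimeError as err:
        print(err)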
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index 0479aba4c54e877..664f7bbd8f1e404 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -6,28 +6,28 @@
namespace at { namespace native {
-at::Tensor mkldnn_convolution(
- const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
+Tensor mkldnn_convolution(
+ const Tensor& input, const Tensor& weight, const Tensor& bias,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
- AT_ERROR("mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
}
-at::Tensor mkldnn_convolution_backward_input(
- IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
+Tensor mkldnn_convolution_backward_input(
+ IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
- AT_ERROR("mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
}
-std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
- IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
+std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
+ IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
- AT_ERROR("mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
}
-std::tuple<at::Tensor, at::Tensor, at::Tensor> mkldnn_convolution_backward(
- const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
+std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
+ const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool, 3> output_mask) {
- AT_ERROR("mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
}
}}
@@ -59,9 +59,9 @@ ideep::tensor _mkldnn_convolution(
const ideep::tensor& x,
const ideep::tensor& w,
const c10::optional<ideep::tensor>& b,
- at::IntArrayRef padding,
- at::IntArrayRef stride,
- at::IntArrayRef dilation,
+ IntArrayRef padding,
+ IntArrayRef stride,
+ IntArrayRef dilation,
int64_t groups) {
auto kernel_size = w.get_dims();
@@ -98,10 +98,10 @@ ideep::tensor _mkldnn_convolution(
return y;
}
-at::Tensor mkldnn_convolution(
- const at::Tensor& input,
- const at::Tensor& weight,
- const at::Tensor& bias,
+Tensor mkldnn_convolution(
+ const Tensor& input,
+ const Tensor& weight,
+ const Tensor& bias,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
@@ -131,7 +131,7 @@ at::Tensor mkldnn_convolution(
}
Tensor mkldnn_convolution_backward_input(
- IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
+ IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
auto mkldnn_grad_output = get_mkldnn_tensor(grad_output);
@@ -153,8 +153,8 @@ Tensor mkldnn_convolution_backward_input(
grad_output.options()));
}
-std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
- IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
+std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
+ IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output);
@@ -193,8 +193,8 @@ std::tuple mkldnn_convolution_backward_weights(
grad_output.options())));
}
-std::tuple<at::Tensor, at::Tensor, at::Tensor> mkldnn_convolution_backward(
- const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
+std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
+ const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool, 3> output_mask)
{
Tensor grad_output = grad_output_t.contiguous();
diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp
index 982d55a7081c868..20479ba9164f83a 100644
--- a/aten/src/ATen/native/mkldnn/Linear.cpp
+++ b/aten/src/ATen/native/mkldnn/Linear.cpp
@@ -11,7 +11,7 @@ Tensor mkldnn_linear(
const Tensor& self,
const Tensor& weight,
const Tensor& bias) {
- AT_ERROR("mkldnn_linear: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_linear: ATen not compiled with MKLDNN support");
}
} // namespace native
diff --git a/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp b/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp
index 9d15f74cdeb0c18..cf005e28f7b6672 100644
--- a/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp
+++ b/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp
@@ -11,7 +11,7 @@ namespace at {
namespace native {
Tensor& mkldnn_zero_(Tensor& self) {
- AT_ERROR("mkldnn_zero_: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_zero_: ATen not compiled with MKLDNN support");
}
} // namespace native
diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp
index a2c6ef8704aef0a..86d9d0643a27c4e 100644
--- a/aten/src/ATen/native/mkldnn/Normalization.cpp
+++ b/aten/src/ATen/native/mkldnn/Normalization.cpp
@@ -17,7 +17,7 @@ std::tuple mkldnn_batch_norm(
bool train,
double momentum,
double eps) {
- AT_ERROR("mkldnn_batch_norm: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_batch_norm: ATen not compiled with MKLDNN support");
}
} // namespace native
@@ -49,7 +49,7 @@ std::tuple mkldnn_batch_norm(
if (train) {
// TODO: support training
- AT_ERROR("mkldnn_batch_norm: mkldnn training is not supported in yet.");
+ TORCH_CHECK(false, "mkldnn_batch_norm: mkldnn training is not supported in yet.");
// ideep::tensor saved_mean;
// ideep::tensor saved_var;
@@ -60,7 +60,7 @@ std::tuple mkldnn_batch_norm(
// new_with_itensor_mkldnn(std::move(saved_mean), input.options()),
// new_with_itensor_mkldnn(std::move(saved_var), input.options()));
} else {
- AT_ASSERTM(input.dim() == 4 || input.dim() == 5,
+ TORCH_CHECK(input.dim() == 4 || input.dim() == 5,
"mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm");
ideep::batch_normalization_forward_inference::compute(
x, m, v, w, b, y, eps);
diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp
index 1757c2882e2a8be..a272bc3d6070b9f 100644
--- a/aten/src/ATen/native/mkldnn/Pooling.cpp
+++ b/aten/src/ATen/native/mkldnn/Pooling.cpp
@@ -184,6 +184,8 @@ Tensor mkldnn_max_pool2d(
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
+ TORCH_CHECK(std::all_of(dilation.cbegin(), dilation.cend(), [](int64_t i) { return 1 == i; }),
+ "mkldnn_max_pool2d does not support dilation case");
return _mkldnn_pooling(
input,
kernel_size,
@@ -201,6 +203,8 @@ Tensor mkldnn_max_pool3d(
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
+ TORCH_CHECK(std::all_of(dilation.cbegin(), dilation.cend(), [](int64_t i) { return 1 == i; }),
+ "mkldnn_max_pool3d does not support dilation case");
return _mkldnn_pooling(
input,
kernel_size,
@@ -220,7 +224,7 @@ Tensor mkldnn_avg_pool2d(
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_CHECK(!divisor_override.has_value(),
- "mkldnn_avg_pool2d operator does not support divisor");
+ "mkldnn_avg_pool2d operator does not support divisor");
return _mkldnn_pooling(
input,
kernel_size,
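The new TORCH_CHECKs reject any dilation other than 1, since the MKL-DNN pooling path does not implement dilated pooling. A hedged sketch of the resulting behavior, assuming a build with MKL-DNN enabled (example code, not part of the change set):

    import torch

    x = torch.randn(1, 1, 8, 8).to_mkldnn()
    torch.max_pool2d(x, kernel_size=3)                    # ok: dilation defaults to 1
    try:
        torch.max_pool2d(x, kernel_size=3, dilation=2)    # now rejected with a clear error
    except RuntimeError as err:
        print(err)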
diff --git a/aten/src/ATen/native/mkldnn/Relu.cpp b/aten/src/ATen/native/mkldnn/Relu.cpp
index b8615c0772e1c45..42397255caf0048 100644
--- a/aten/src/ATen/native/mkldnn/Relu.cpp
+++ b/aten/src/ATen/native/mkldnn/Relu.cpp
@@ -8,11 +8,11 @@
namespace at { namespace native {
Tensor mkldnn_relu(const Tensor& input) {
- AT_ERROR("mkldnn_relu: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_relu: ATen not compiled with MKLDNN support");
}
Tensor& mkldnn_relu_(Tensor& input) {
- AT_ERROR("mkldnn_relu_: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_relu_: ATen not compiled with MKLDNN support");
}
}}
diff --git a/aten/src/ATen/native/mkldnn/SoftMax.cpp b/aten/src/ATen/native/mkldnn/SoftMax.cpp
index 80f94d74b965bec..cdeb6cb85971967 100644
--- a/aten/src/ATen/native/mkldnn/SoftMax.cpp
+++ b/aten/src/ATen/native/mkldnn/SoftMax.cpp
@@ -11,7 +11,7 @@ Tensor mkldnn_softmax(
const Tensor& self,
const int64_t dim,
const bool half_to_float) {
- AT_ERROR("mkldnn_softmax: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_softmax: ATen not compiled with MKLDNN support");
}
} // namespace native
@@ -28,7 +28,7 @@ Tensor mkldnn_softmax(
const Tensor& self,
const int64_t dim,
const bool half_to_float) {
- AT_ASSERTM(
+ TORCH_CHECK(
!half_to_float,
"softmax with half to float conversion is not supported on Mkldnn");
const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim());
diff --git a/aten/src/ATen/native/mkldnn/TensorFactories.cpp b/aten/src/ATen/native/mkldnn/TensorFactories.cpp
index ec9a7a5429142e2..603819ed3287382 100644
--- a/aten/src/ATen/native/mkldnn/TensorFactories.cpp
+++ b/aten/src/ATen/native/mkldnn/TensorFactories.cpp
@@ -21,7 +21,7 @@ Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::option
#else
Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
- AT_ERROR("empty_mkldnn: MKL-DNN build is disabled");
+ TORCH_CHECK(false, "empty_mkldnn: MKL-DNN build is disabled");
}
#endif // AT_MKLDNN_ENABLED()
diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp
index af33617a32ed4da..3229a07e94609db 100644
--- a/aten/src/ATen/native/mkldnn/TensorShape.cpp
+++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp
@@ -9,23 +9,23 @@ namespace at {
namespace native {
Tensor mkldnn_view(const Tensor& self, IntArrayRef size) {
- AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_reshape: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
- AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_reshape: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
- AT_ERROR("mkldnn_clone: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_clone: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
- AT_ERROR("mkldnn_transpose: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_transpose: ATen not compiled with MKLDNN support");
}
Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) {
- AT_ERROR("mkldnn_transpose_: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_transpose_: ATen not compiled with MKLDNN support");
}
} // namespace native
@@ -39,7 +39,7 @@ namespace at {
namespace native {
Tensor mkldnn_view(const Tensor& self, IntArrayRef size) {
- AT_ERROR(
+ TORCH_CHECK(false,
"Currently Mkldnn tensor does not support view. Change to use reshape instead");
}
@@ -65,7 +65,7 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional optiona
return new_with_itensor_mkldnn(std::move(dst), self.options());
}
-Tensor mkldnn_transpose(const Tensor & self, int64_t dim0, int64_t dim1) {
+Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
const ideep::tensor& x = itensor_from_mkldnn(self);
ideep::tensor y;
std::vector<int> axes(x.ndims());
@@ -76,7 +76,7 @@ Tensor mkldnn_transpose(const Tensor & self, int64_t dim0, int64_t dim1) {
}
Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) {
- AT_ERROR("mkldnn_transpose_: in-place mkldnn operations are not supported yet");
+ TORCH_CHECK(false, "mkldnn_transpose_: in-place mkldnn operations are not supported yet");
}
} // namespace native
diff --git a/aten/src/ATen/native/mkldnn/UnaryOps.cpp b/aten/src/ATen/native/mkldnn/UnaryOps.cpp
index 5045acd60a57aed..4eb02dc483c5ce4 100644
--- a/aten/src/ATen/native/mkldnn/UnaryOps.cpp
+++ b/aten/src/ATen/native/mkldnn/UnaryOps.cpp
@@ -8,11 +8,11 @@ namespace at {
namespace native {
Tensor mkldnn_sigmoid(const Tensor& self) {
- AT_ERROR("mkldnn_sigmoid: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_sigmoid: ATen not compiled with MKLDNN support");
}
Tensor& mkldnn_sigmoid_(Tensor& self) {
- AT_ERROR("mkldnn_sigmoid_: ATen not compiled with MKLDNN support");
+ TORCH_CHECK(false, "mkldnn_sigmoid_: ATen not compiled with MKLDNN support");
}
} // namespace native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7d55d18fe4f87d5..d4faba561329753 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -495,6 +495,17 @@
- func: asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+# arcsinh, alias for asinh
+- func: arcsinh(Tensor self) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+
+- func: arcsinh_(Tensor(a!) self) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: function, method
+
+- func: arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
- func: atanh(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
@@ -505,6 +516,17 @@
- func: atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+# arctanh, alias for atanh
+- func: arctanh(Tensor self) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+
+- func: arctanh_(Tensor(a!) self) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: function, method
+
+- func: arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
- func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method
@@ -1236,6 +1258,9 @@
CPU: _embedding_bag_forward_only_cpu
CUDA: _embedding_bag_forward_only_cuda
+- func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
+ use_c10_dispatcher: full
+
- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
use_c10_dispatcher: full
@@ -3584,6 +3609,26 @@
use_c10_dispatcher: full
variants: method
+# subtract, alias for sub
+- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+
+- func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+
+- func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: method
+
+# For C++ only, until we have conversion from C++ numbers to Tensor
+- func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+
+- func: subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: method
+
- func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
use_c10_dispatcher: full
variants: function
@@ -5727,9 +5772,16 @@
device_guard: False
variants: function
dispatch:
- CPU: foreach_add_scalar_kernel_fallback
+ CPU: foreach_tensor_add_scalar_kernel_slow
CUDA: foreach_tensor_add_scalar_kernel_cuda
+- func: _foreach_add_.Scalar(Tensor[](a!) self, Scalar scalar) -> ()
+ device_guard: False
+ variants: function
+ dispatch:
+ CPU: foreach_tensor_add_scalar_kernel_slow_
+ CUDA: foreach_tensor_add_scalar_kernel_cuda_
+
- func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor)
use_c10_dispatcher: full
dispatch:
@@ -7304,6 +7356,22 @@
- func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+- func: linalg_norm(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+ python_module: linalg
+ variants: function
+
+- func: linalg_norm.ord_str(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+ python_module: linalg
+ variants: function
+
+- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+ python_module: linalg
+ variants: function
+
+- func: linalg_norm.ord_str_out(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+ python_module: linalg
+ variants: function
+
## Functions that are only for testing
# It is undocumented and should not be used outside of tests.
- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
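The schema entries above add arcsinh/arctanh and subtract as aliases, an in-place _foreach_add_ variant, and the first torch.linalg entry points. A rough sketch of the Python surface they imply, hedged because the bindings and docs land elsewhere in this PR (example code, not part of the change set):

    import torch

    x = torch.randn(4)
    torch.arcsinh(x)                                   # alias for torch.asinh
    torch.arctanh(x.clamp(-0.5, 0.5))                  # alias for torch.atanh
    torch.subtract(x, 1.0, alpha=2)                    # alias for torch.sub
    torch._foreach_add_([x.clone(), x.clone()], 1.0)   # in-place foreach variant
    torch.linalg.norm(x, ord=2)                        # new linalg namespace function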
diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
new file mode 100644
index 000000000000000..1d6010490c81d79
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
@@ -0,0 +1,301 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+#include
+
+/* Convolution prepacked parameters serialization.
+ *
+ * Version 1
+ *
+ * - Fields:
+ * 1. weight
+ * 2. bias
+ * 3. stride x kSpatialDim
+ * 4. padding x kSpatialDim
+ * 5. dilation x kSpatialDim
+ * 6. groups
+ *
+ * Version 2
+ *
+ * - Fields:
+ * 0. version (string)
+ * 1. list of non-optional tensors
+ * 0: packed parameters (int16_t)
+ * - kSpatialDim
+ * - stride x kSpatialDim
+ * - padding x kSpatialDim
+ * - dilation x kSpatialDim
+ * - output_padding x kSpatialDim (unused)
+ * - groups
+ * - transpose (0 or 1, unused)
+ * 1: weight
+ * 2. list of optional tensors
+ * 0: bias
+ *
+ * Note: version is a string and conv params are packed into a Tensor
+ * to make ONNX happy (ints and containers of ints are not supported).
+ */
+
+// version 1
+using ConvParamsSerializationTypeLegacy = std::tuple<
+ // weight
+ at::Tensor,
+ // bias
+ c10::optional<at::Tensor>,
+ // stride x kSpatialDim
+ torch::List<at::Tensor>,
+ // padding x kSpatialDim
+ torch::List<at::Tensor>,
+ // dilation x kSpatialDim
+ torch::List<at::Tensor>,
+ // groups
+ at::Tensor>;
+
+// version 2
+using ConvParamsSerializationType = std::tuple<
+ // version, for versions 2 and up
+ std::string,
+ // non-optional tensors
+ std::vector<at::Tensor>,
+ // optional tensors
+ std::vector<c10::optional<at::Tensor>>>;
+
+// Parses any historical conv packed params format into
+// the current format.
+template <uint32_t kSpatialDim>
+ConvParamsSerializationType parse_conv_serialized_state(c10::IValue v) {
+
+ // determine the version based on IValue contents
+ int version = -1;
+ if (v.isTuple()) {
+ auto elements = v.toTuple()->elements();
+ if (elements.size() > 0) {
+ auto firstElement = elements[0];
+ if (firstElement.isTensor()) {
+ version = 1;
+ } else if (firstElement.isString()) {
+ std::string version_str = firstElement.toStringRef();
+ // note: not parsing the string to automatically handle bad
+ // inputs
+ if (version_str == "2") {
+ version = 2;
+ }
+ }
+ }
+ }
+ TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");
+
+ if (version == 1) {
+ // version 1 - convert to version 2 manually
+
+ auto elements = v.toTuple()->elements();
+
+ at::Tensor weight = elements[0].toTensor();
+ c10::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
+ torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
+ torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
+ torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
+ at::Tensor groups = elements[5].toTensor();
+
+ std::string version = "2";
+ std::vector<at::Tensor> non_optional;
+ std::vector<c10::optional<at::Tensor>> optional;
+
+ std::vector<int16_t> params_vec;
+ params_vec.push_back(kSpatialDim);
+ for (int i = 0; i < stride_x_kSpatialDim.size(); i++) {
+ auto stride = stride_x_kSpatialDim.get(i);
+ params_vec.push_back(stride[0].item<int16_t>());
+ }
+ for (int i = 0; i < padding_x_kSpatialDim.size(); i++) {
+ auto padding = padding_x_kSpatialDim.get(i);
+ params_vec.push_back(padding[0].item<int16_t>());
+ }
+ for (int i = 0; i < dilation_x_kSpatialDim.size(); i++) {
+ auto dilation = dilation_x_kSpatialDim.get(i);
+ params_vec.push_back(dilation[0].item<int16_t>());
+ }
+ // output_padding does not exist in v1, so we fill in a default value
+ for (int i = 0; i < kSpatialDim; i++) {
+ params_vec.push_back(0);
+ }
+ params_vec.push_back(groups[0].item<int16_t>());
+ // transpose does not exist in v1, so we fill in a default value
+ params_vec.push_back(0);
+ int64_t vec_size = params_vec.size();
+ at::Tensor params_tensor = at::from_blob(params_vec.data(),
+ {vec_size}, at::TensorOptions().dtype(at::kShort))
+ // clone to retain ownership of the data
+ .clone();
+
+ non_optional.emplace_back(std::move(params_tensor));
+ non_optional.emplace_back(std::move(weight));
+ optional.emplace_back(std::move(bias));
+
+ return std::tie(version, non_optional, optional);
+ } else if (version == 2) {
+ // version 2
+ return v.to<ConvParamsSerializationType>();
+ } else {
+ TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
+ version);
+ }
+}
+
+template <uint32_t kSpatialDim>
+ConvParamsSerializationType serialize_conv(
+ const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
+
+ std::string version = "2";
+ std::vector<at::Tensor> non_optional;
+ std::vector<c10::optional<at::Tensor>> optional;
+
+ // create a packed int8_t tensor for conv params
+ std::vector<int16_t> params_vec;
+ params_vec.push_back(kSpatialDim);
+ auto stride = params->stride().vec();
+ params_vec.insert(params_vec.end(), stride.begin(), stride.end());
+ auto padding = params->padding().vec();
+ params_vec.insert(params_vec.end(), padding.begin(), padding.end());
+ auto dilation = params->dilation().vec();
+ params_vec.insert(params_vec.end(), dilation.begin(), dilation.end());
+ // output_padding is not implemented yet, so we fill in a default value
+ for (int i = 0; i < kSpatialDim; i++) {
+ params_vec.push_back(0);
+ }
+ params_vec.push_back(params->groups());
+ // transpose is not implemented yet, so we fill in a default value
+ params_vec.push_back(0);
+ int64_t vec_size = params_vec.size();
+ at::Tensor params_tensor = at::from_blob(
+ params_vec.data(), {vec_size},
+ at::TensorOptions().dtype(at::kShort))
+ // clone to retain ownership of the data
+ .clone();
+
+ at::Tensor weight;
+ c10::optional<at::Tensor> bias;
+ std::tie(weight, bias) = params->unpack();
+
+ non_optional.emplace_back(std::move(params_tensor));
+ non_optional.emplace_back(std::move(weight));
+ optional.emplace_back(std::move(bias));
+
+ return std::tie(version, non_optional, optional);
+}
+
+template <uint32_t kSpatialDim>
+ConvParamsSerializationTypeLegacy serialize_conv_legacy(
+ const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
+ at::Tensor weight;
+ c10::optional<at::Tensor> bias;
+ std::tie(weight, bias) = params->unpack();
+ torch::List<at::Tensor> stride;
+ torch::List<at::Tensor> padding;
+ torch::List<at::Tensor> dilation;
+ at::Tensor groups;
+ for (int64_t s : params->stride()) {
+ stride.emplace_back(at::tensor(s));
+ }
+ for (int64_t p : params->padding()) {
+ padding.emplace_back(at::tensor(p));
+ }
+ for (int64_t d : params->dilation()) {
+ dilation.emplace_back(at::tensor(d));
+ }
+ groups = at::tensor(params->groups());
+ return std::make_tuple(
+ std::move(weight),
+ std::move(bias),
+ stride,
+ padding,
+ dilation,
+ groups);
+}
+
+template <uint32_t kSpatialDim>
+c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
+ ConvParamsSerializationType state) {
+
+ std::string version;
+ std::vector<at::Tensor> non_optional;
+ std::vector<c10::optional<at::Tensor>> optional;
+
+ std::tie(version, non_optional, optional) = state;
+ TORCH_INTERNAL_ASSERT(version == "2", "Unexpected serialized qconv version: ",
+ version);
+
+ at::Tensor conv_params_packed = non_optional[0];
+ at::Tensor weight = non_optional[1];
+ c10::optional<at::Tensor> bias = optional[0];
+
+ torch::List<int64_t> stride, padding, dilation;
+ // skip kSpatialDim
+ int idx = 1;
+ for (int i = 0; i < kSpatialDim; ++i) {
+ stride.emplace_back(conv_params_packed[idx].item<int64_t>());
+ idx++;
+ }
+ for (int i = 0; i < kSpatialDim; ++i) {
+ padding.emplace_back(conv_params_packed[idx].item<int64_t>());
+ idx++;
+ }
+ for (int i = 0; i < kSpatialDim; ++i) {
+ dilation.emplace_back(conv_params_packed[idx].item<int64_t>());
+ idx++;
+ }
+ // output_padding is not implemented yet, so we skip the entries
+ for (int i = 0; i < kSpatialDim; ++i) {
+ // do nothing
+ idx++;
+ }
+ int64_t groups = conv_params_packed[idx].item<int64_t>();
+ idx++;
+ // transpose is not implemented yet, so we skip the entry
+ idx++;
+ TORCH_INTERNAL_ASSERT(idx == conv_params_packed.numel(),
+ "Unexpected length of conv_params_packed, expected ",
+ idx,
+ " got ",
+ conv_params_packed.numel());
+
+ auto& ctx = at::globalContext();
+
+#ifdef USE_FBGEMM
+ if (ctx.qEngine() == at::QEngine::FBGEMM) {
+ return PackedConvWeight<kSpatialDim>::prepack(
+ weight,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups
+ );
+ }
+#endif // USE_FBGEMM
+#ifdef USE_PYTORCH_QNNPACK
+ if (ctx.qEngine() == at::QEngine::QNNPACK) {
+ TORCH_CHECK(
+ kSpatialDim == 2,
+ "prepack/__setstate__: QNNPACK only supports Conv2d "
+ "now.");
+ return PackedConvWeightsQnnp<kSpatialDim>::prepack(
+ weight,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups
+ );
+ }
+#endif // USE_PYTORCH_QNNPACK
+TORCH_CHECK(
+ false,
+ "Didn't find engine for when deserializing ConvPackedParams: ",
+ toString(ctx.qEngine()));
+}
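To make the version-2 layout documented in the header above concrete, here is a hedged sketch of what a serialized state could look like for a hypothetical 2d conv with stride 1, padding 0, dilation 1 and groups 1; the tensor values are illustrative only (example code, not part of the change set):

    import torch

    params = torch.tensor(
        [2,        # kSpatialDim
         1, 1,     # stride
         0, 0,     # padding
         1, 1,     # dilation
         0, 0,     # output_padding (unused in this version)
         1,        # groups
         0],       # transpose (unused in this version)
        dtype=torch.int16)
    weight = torch.randn(8, 4, 3, 3)            # placeholder weight tensor
    state = ("2", [params, weight], [None])     # (version, non-optional tensors, optional bias)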
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
index 48fdc04fe159409..6b93d50104f11a5 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -1,10 +1,11 @@
-#include
-#include
-#include
-#include
-
#include
#include
+
+#include
+#include
+#include
+#include
+#include
#include
#include
@@ -213,99 +214,27 @@ Tensor ConvertToChannelsLast3dTensor(const Tensor& src) {
template <int kSpatialDim>
CAFFE2_API torch::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params() {
- using SerializationType = std::tuple<
- at::Tensor,
- c10::optional<at::Tensor>,
- // these are meant to be torch::List<int64_t> but
- // it's not supported by onnx, so we'll use Tensor as
- // a workaround
- torch::List<at::Tensor>,
- torch::List<at::Tensor>,
- torch::List<at::Tensor>,
- at::Tensor>;
static auto register_conv_params =
torch::class_<ConvPackedParamsBase<kSpatialDim>>(
"quantized", "Conv" + c10::to_string(kSpatialDim) + "dPackedParamsBase")
.def_pickle(
+ /*
+ [](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params)
+ -> ConvParamsSerializationType { // __getstate__
+ return serialize_conv<kSpatialDim>(params);
+ },
+ */
+ // TODO (#43649): switch this to serialize_conv
[](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params)
- -> SerializationType { // __getstate__
- at::Tensor weight;
- c10::optional<at::Tensor> bias;
- std::tie(weight, bias) = params->unpack();
- torch::List<at::Tensor> stride;
- torch::List<at::Tensor> padding;
- torch::List<at::Tensor> dilation;
- at::Tensor groups;
- for (int64_t s : params->stride()) {
- stride.emplace_back(at::tensor(s));
- }
- for (int64_t p : params->padding()) {
- padding.emplace_back(at::tensor(p));
- }
- for (int64_t d : params->dilation()) {
- dilation.emplace_back(at::tensor(d));
- }
- groups = at::tensor(params->groups());
- return std::make_tuple(
- std::move(weight),
- std::move(bias),
- stride,
- padding,
- dilation,
- groups);
+ -> ConvParamsSerializationTypeLegacy { // __getstate__
+ return serialize_conv_legacy<kSpatialDim>(params);
},
- [](SerializationType state)
+ // __setstate__ takes c10::IValue because we support parsing historical
+ // serialization versions.
+ [](c10::IValue v)
-> c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> { // __setstate__
- at::Tensor weight;
- c10::optional