diff --git a/.circleci/config.yml b/.circleci/config.yml
index 0716e516518b..d19c08b2b0b6 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -11,6 +11,9 @@ parameters:
run_binary_tests:
type: boolean
default: false
+ run_build:
+ type: boolean
+ default: true
docker_config_defaults: &docker_config_defaults
user: jenkins
@@ -9762,6 +9765,7 @@ workflows:
only:
- postnightly
executor: windows-with-nvidia-gpu
+ when: << pipeline.parameters.run_build >>
ecr_gc:
triggers:
- schedule:
diff --git a/.circleci/generate_config_yml.py b/.circleci/generate_config_yml.py
index f1af924bd3e2..a836d2e510a6 100755
--- a/.circleci/generate_config_yml.py
+++ b/.circleci/generate_config_yml.py
@@ -112,7 +112,10 @@ def gen_build_workflows_tree():
"when": r"<< pipeline.parameters.run_binary_tests >>",
"jobs": [f() for f in binary_build_functions],
},
- "build": {"jobs": [f() for f in build_workflows_functions]},
+ "build": {
+ "when": r"<< pipeline.parameters.run_build >>",
+ "jobs": [f() for f in build_workflows_functions]
+ },
}
}
diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh
index 0b2e60b48f8e..26cc77c8ff9c 100755
--- a/.circleci/scripts/binary_linux_test.sh
+++ b/.circleci/scripts/binary_linux_test.sh
@@ -51,7 +51,14 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then
else
cu_ver="${DESIRED_CUDA:2:2}.${DESIRED_CUDA:4}"
fi
- retry conda install \${EXTRA_CONDA_FLAGS} -yq -c nvidia -c pytorch "cudatoolkit=\${cu_ver}"
+ (
+ # For some reason conda likes to re-activate the conda environment when attempting this install,
+ # which runs a deactivate; some variables (namely CONDA_MKL_INTERFACE_LAYER_BACKUP from libblas)
+ # might not exist at that point, so ignore unbound variables for the conda install commands only.
+ set +u
+ retry conda install \${EXTRA_CONDA_FLAGS} -yq -c nvidia -c pytorch "cudatoolkit=\${cu_ver}"
+ )
fi
elif [[ "$PACKAGE_TYPE" != libtorch ]]; then
pip install "\$pkg"
diff --git a/.circleci/verbatim-sources/header-section.yml b/.circleci/verbatim-sources/header-section.yml
index 26205a0cccba..43d4c94ee5ed 100644
--- a/.circleci/verbatim-sources/header-section.yml
+++ b/.circleci/verbatim-sources/header-section.yml
@@ -11,6 +11,9 @@ parameters:
run_binary_tests:
type: boolean
default: false
+ run_build:
+ type: boolean
+ default: true
docker_config_defaults: &docker_config_defaults
user: jenkins
diff --git a/.github/pytorch-circleci-labels.yml b/.github/pytorch-circleci-labels.yml
index ccdf2e876af1..3a9eeca0abcc 100644
--- a/.github/pytorch-circleci-labels.yml
+++ b/.github/pytorch-circleci-labels.yml
@@ -9,3 +9,5 @@ labels_to_circle_params:
- release/.*
tags:
- v[0-9]+(\.[0-9]+)*-rc[0-9]+
+ set_to_false:
+ - run_build
diff --git a/.github/workflows/update_s3_htmls.yml b/.github/workflows/update_s3_htmls.yml
index 92f9a66a0fd8..f2320ce2fcbf 100644
--- a/.github/workflows/update_s3_htmls.yml
+++ b/.github/workflows/update_s3_htmls.yml
@@ -9,6 +9,7 @@ on:
jobs:
update-html:
runs-on: ubuntu-latest
+ if: ${{ github.repository_owner == 'pytorch' }}
strategy:
matrix:
prefix: ["whl", "whl/test", "whl/nightly"]
diff --git a/.jenkins/pytorch/README.md b/.jenkins/pytorch/README.md
index ea6c6dd40f68..9fd68ecf7f15 100644
--- a/.jenkins/pytorch/README.md
+++ b/.jenkins/pytorch/README.md
@@ -10,9 +10,9 @@ it is very easy to run these tests yourself:
``registry.pytorch.org/pytorch/pytorch-$BUILD_ENVIRONMENT:$DOCKER_VERSION``,
where ``$BUILD_ENVIRONMENT`` is one of the build environments
enumerated in
- [pytorch-dockerfiles](https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh)
+ [pytorch-dockerfiles](https://github.com/pytorch/pytorch/blob/master/.circleci/docker/build.sh). The Dockerfiles used by Jenkins can be found under the `.circleci` [directory](https://github.com/pytorch/pytorch/blob/master/.circleci/docker).
-2. Run ``docker -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and
+2. Run ``docker run -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and
run one of the scripts in this directory.
The Docker images are designed so that any "reasonable" build commands
@@ -38,5 +38,5 @@ mechanisms we use:
build scripts.
- We reroute well known paths like `/usr/bin/gcc` to alternate
- implementations with `update-alternatives, instead of setting
+ implementations with `update-alternatives`, instead of setting
`CC` and `CXX` in our implementations.
diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh
index 17e7e9fa3445..47d13f2908d0 100755
--- a/.jenkins/pytorch/codegen-test.sh
+++ b/.jenkins/pytorch/codegen-test.sh
@@ -48,13 +48,6 @@ python -m tools.autograd.gen_autograd \
"$OUT"/autograd \
tools/autograd
-# unboxing_wrappers codegen (called by torch codegen but can run independently)
-mkdir -p "$OUT"/unboxing_wrappers
-python -m tools.jit.gen_unboxing_wrappers \
- "$OUT"/torch/share/ATen/Declarations.yaml \
- "$OUT"/unboxing_wrappers \
- tools/jit/templates
-
# annotated_fn_args codegen (called by torch codegen but can run independently)
mkdir -p "$OUT"/annotated_fn_args
python -m tools.autograd.gen_annotated_fn_args \
diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh
index 0c34ddcc6179..24ec02c76df5 100755
--- a/.jenkins/pytorch/macos-test.sh
+++ b/.jenkins/pytorch/macos-test.sh
@@ -9,11 +9,6 @@ pip install -q hypothesis "librosa>=0.6.2" "numba<=0.49.1" psutil
# TODO move this to docker
pip install unittest-xml-reporting pytest
-# faulthandler become built-in since 3.3
-if [[ ! $(python -c "import sys; print(int(sys.version_info >= (3, 3)))") == "1" ]]; then
- pip install -q faulthandler
-fi
-
if [ -z "${IN_CI}" ]; then
rm -rf ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch*
fi
diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
index a052a1b67d59..ed6482890993 100644
--- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
+++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
@@ -41,8 +41,6 @@ popd
:: The version is fixed to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest coverage
if %errorlevel% neq 0 ( exit /b %errorlevel% )
-:: No need to install faulthandler since we only test Python >= 3.6 on Windows
-:: faulthandler is builtin since Python 3.3
set DISTUTILS_USE_SDK=1
diff --git a/BUILD.bazel b/BUILD.bazel
index b3faea487965..2b4636d850c9 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -193,9 +193,6 @@ libtorch_cpp_generated_sources = [
"torch/csrc/autograd/generated/Functions.h",
"torch/csrc/autograd/generated/Functions.cpp",
"torch/csrc/autograd/generated/variable_factories.h",
- "torch/csrc/jit/generated/generated_unboxing_wrappers_0.cpp",
- "torch/csrc/jit/generated/generated_unboxing_wrappers_1.cpp",
- "torch/csrc/jit/generated/generated_unboxing_wrappers_2.cpp",
]
libtorch_python_generated_sources = [
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ba862b5a4d5f..3df73f8a3041 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -173,6 +173,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF)
cmake_dependent_option(
USE_NCCL "Use NCCL" ON
"USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
+cmake_dependent_option(USE_RCCL "Use RCCL" ON
+ USE_NCCL OFF)
cmake_dependent_option(
USE_STATIC_NCCL "Use static NCCL" OFF
"USE_NCCL" OFF)
@@ -316,7 +318,7 @@ set(OP_DEPENDENCY "" CACHE STRING
# symbol lookup error: miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: mkl_blas_dsyrk
# https://software.intel.com/en-us/articles/symbol-lookup-error-when-linking-intel-mkl-with-gcc-on-ubuntu
if(LINUX)
- set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed")
+ set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed ${CMAKE_SHARED_LINKER_FLAGS}")
endif()
if(MSVC)
diff --git a/android/test_app/app/src/main/AndroidManifest.xml b/android/test_app/app/src/main/AndroidManifest.xml
index a83bf223bdaf..abdd9a8d986a 100644
--- a/android/test_app/app/src/main/AndroidManifest.xml
+++ b/android/test_app/app/src/main/AndroidManifest.xml
@@ -18,4 +18,10 @@
+
+
+
+
diff --git a/aten/conda/meta.yaml b/aten/conda/meta.yaml
index d8096fc73a0f..a502690a5447 100644
--- a/aten/conda/meta.yaml
+++ b/aten/conda/meta.yaml
@@ -24,7 +24,7 @@ requirements:
- mkl # [not osx]
about:
- home: https://github.com/zdevito/ATen
+ home: https://github.com/pytorch/pytorch
license: BSD
summary: A TENsor library for C++14
diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h
index ae95ef43f21c..8d29a9204420 100644
--- a/aten/src/ATen/ATen.h
+++ b/aten/src/ATen/ATen.h
@@ -31,3 +31,4 @@
#include
#include
#include
+#include
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index 419c454257d8..2cd7cac4e71b 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -287,6 +287,25 @@ Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
return self_physical.getPhysicalToLogicalMap().apply(result);
}
+Tensor trace_batching_rule(const Tensor& self) {
+ auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+ // Batched Diagonal View
+ auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
+ auto result = at::sum(self_diag, -1);
+ return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
+Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
+ auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
+ auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
+ // Batched Diagonal View
+ auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
+ // Append a dimension of size one to the grad output
+ auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
+ grad_input_diag.copy_(grad_physical_tensor);
+ return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
+}
+
Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
// PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
// for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
@@ -996,7 +1015,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("_add_batch_dim", native::_add_batch_dim);
m.impl("_remove_batch_dim", native::_remove_batch_dim);
- m.impl_UNBOXED("sum.dim_IntList", sum_batching_rule);
+ m.impl("sum.dim_IntList", sum_batching_rule);
m.impl("is_complex", native::is_complex);
m.impl("conj", native::conj);
@@ -1029,6 +1048,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("squeeze", squeeze_batching_rule);
m.impl("squeeze.dim", squeeze_dim_batching_rule);
m.impl("t", native::t); // composite wrt autograd
+ m.impl("trace", trace_batching_rule);
m.impl("transpose.int", transpose_int_batching_rule);
m.impl("unbind.int", unbind_batching_rule);
m.impl("unfold", unfold_batching_rule);
@@ -1089,6 +1109,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
#undef TO_BATCHING_RULE
m.impl("clone", clone_batching_rule);
+ using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
using TensorScalarType = Tensor (*)(const Tensor&, Scalar);
@@ -1115,6 +1136,12 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);
m.impl("sigmoid_backward", binary_pointwise_batching_rule);
+ m.impl(
+ "threshold_backward",
+ binary_pointwise_batching_rule<
+ TensorTensorScalarType,
+ at::threshold_backward,
+ Scalar>);
// for at::result_type, call the native::result_type implementation.
// We don't have to do anything special because native::result_type operates
@@ -1150,6 +1177,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
// backward operators
m.impl("select_backward", select_backward_batching_rule);
m.impl("slice_backward", slice_backward_batching_rule);
+ m.impl("trace_backward", trace_backward_batching_rule);
m.impl("diagonal_backward", diagonal_backward_batching_rule);
// Tensor.new_* operators
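
A minimal standalone sketch (not part of the patch) of the identity the new trace rules above rely on: for a physical tensor with the batch dimension at the front, a batched trace is just a batched diagonal view followed by a sum over the last dimension.

    #include <ATen/ATen.h>

    int main() {
      // A batch of five 3x3 matrices, batch dim at the front, matching what
      // MultiBatchVmapTransform::logicalToPhysical produces.
      auto x = at::randn({5, 3, 3});
      auto batched = at::diagonal(x, /*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).sum(-1);
      for (int64_t i = 0; i < x.size(0); ++i) {
        TORCH_CHECK(at::allclose(batched[i], at::trace(x[i])));
      }
      return 0;
    }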
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index fd3c95f2573b..6fedef185b21 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -72,7 +72,7 @@ file(GLOB metal_h "metal/*.h")
file(GLOB metal_cpp "metal/*.cpp")
file(GLOB_RECURSE native_metal_h "native/metal/*.h")
file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm")
-file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm", "native/metal/*.cpp")
+file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm" "native/metal/*.cpp")
EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs})
file(GLOB metal_prepack_h "native/metal/MetalPrepackOpContext.h")
file(GLOB metal_prepack_cpp "native/metal/MetalPrepackOpRegister.cpp")
diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp
index bfa4a2a8f72f..ff4a2f1c61e2 100644
--- a/aten/src/ATen/CPUGeneratorImpl.cpp
+++ b/aten/src/ATen/CPUGeneratorImpl.cpp
@@ -1,4 +1,6 @@
#include
+#include
+#include
#include
#include
@@ -6,6 +8,42 @@ namespace at {
namespace detail {
+/**
+ * CPUGeneratorImplStateLegacy is a POD class needed for memcpys
+ * in torch.get_rng_state() and torch.set_rng_state().
+ * It is a legacy class and even though it is replaced with
+ * at::CPUGeneratorImpl, we need this class and some of its fields
+ * to support backward compatibility on loading checkpoints.
+ */
+struct CPUGeneratorImplStateLegacy {
+ /* The initial seed. */
+ uint64_t the_initial_seed;
+ int left; /* = 1; */
+ int seeded; /* = 0; */
+ uint64_t next;
+ uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */
+
+ /********************************/
+
+ /* For normal distribution */
+ double normal_x;
+ double normal_y;
+ double normal_rho;
+ int normal_is_valid; /* = 0; */
+};
+
+/**
+ * CPUGeneratorImplState is a POD class containing
+ * new data introduced in at::CPUGeneratorImpl and the legacy state. It is used
+ * as a helper for torch.get_rng_state() and torch.set_rng_state()
+ * functions.
+ */
+struct CPUGeneratorImplState {
+ CPUGeneratorImplStateLegacy legacy_pod;
+ float next_float_normal_sample;
+ bool is_next_float_normal_sample_valid;
+};
+
/**
* PyTorch maintains a collection of default generators that get
* initialized once. The purpose of these default generators is to
@@ -75,6 +113,128 @@ uint64_t CPUGeneratorImpl::seed() {
return random;
}
+/**
+ * Sets the internal state of CPUGeneratorImpl. The new internal state
+ * must be a strided CPU byte tensor and of the same size as either
+ * CPUGeneratorImplStateLegacy (for legacy CPU generator state) or
+ * CPUGeneratorImplState (for new state).
+ *
+ * FIXME: Remove support of the legacy state in the future?
+ */
+void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
+ using detail::CPUGeneratorImplState;
+ using detail::CPUGeneratorImplStateLegacy;
+
+ static_assert(std::is_pod<CPUGeneratorImplStateLegacy>::value, "CPUGeneratorImplStateLegacy is not a PODType");
+ static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
+
+ static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy);
+ static const size_t size_current = sizeof(CPUGeneratorImplState);
+ static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size");
+
+ detail::check_rng_state(new_state);
+
+ at::mt19937 engine;
+ auto float_normal_sample = c10::optional<float>();
+ auto double_normal_sample = c10::optional<double>();
+
+ // Construct the state of at::CPUGeneratorImpl based on input byte tensor size.
+ CPUGeneratorImplStateLegacy* legacy_pod;
+ auto new_state_size = new_state.numel();
+ if (new_state_size == size_legacy) {
+ legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data();
+ // Note that in CPUGeneratorImplStateLegacy, we didn't have float version
+ // of normal sample and hence we leave the c10::optional<float> as is
+
+ // Update next_double_normal_sample.
+ // Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y)
+ // and a rho value (normal_rho). These three values were redundant and in the new
+ // DistributionsHelper.h, we store the actual extra normal sample, rather than three
+ // intermediate values.
+ if (legacy_pod->normal_is_valid) {
+ auto r = legacy_pod->normal_rho;
+ auto theta = 2.0 * M_PI * legacy_pod->normal_x;
+ // we return the sin version of the normal sample when in caching mode
+ double_normal_sample = c10::optional<double>(r * ::sin(theta));
+ }
+ } else if (new_state_size == size_current) {
+ auto rng_state = (CPUGeneratorImplState*)new_state.data();
+ legacy_pod = &rng_state->legacy_pod;
+ // update next_float_normal_sample
+ if (rng_state->is_next_float_normal_sample_valid) {
+ float_normal_sample = c10::optional(rng_state->next_float_normal_sample);
+ }
+
+ // Update next_double_normal_sample.
+ // Note that in getRNGState, we now return the actual normal sample in normal_y
+ // and if it's valid in normal_is_valid. The redundant normal_x and normal_rho
+ // are squashed to 0.0.
+ if (legacy_pod->normal_is_valid) {
+ double_normal_sample = c10::optional(legacy_pod->normal_y);
+ }
+ } else {
+ AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy,
+ " or a CPUGeneratorImplState of size ", size_current,
+ " but found the input RNG state size to be ", new_state_size);
+ }
+
+ // construct engine_
+ // Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our
+ // redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are
+ // doing a std::copy.
+ at::mt19937_data_pod rng_data;
+ std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin());
+ rng_data.seed_ = legacy_pod->the_initial_seed;
+ rng_data.left_ = legacy_pod->left;
+ rng_data.seeded_ = legacy_pod->seeded;
+ rng_data.next_ = static_cast(legacy_pod->next);
+ engine.set_data(rng_data);
+ TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state");
+ this->engine_ = engine;
+ this->next_float_normal_sample_ = float_normal_sample;
+ this->next_double_normal_sample_ = double_normal_sample;
+}
+
+/**
+ * Gets the current internal state of CPUGeneratorImpl. The internal
+ * state is returned as a CPU byte tensor.
+ */
+c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
+ using detail::CPUGeneratorImplState;
+
+ static const size_t size = sizeof(CPUGeneratorImplState);
+ static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
+
+ auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+ auto rng_state = state_tensor.data_ptr<uint8_t>();
+
+ // accumulate generator data to be copied into byte tensor
+ auto accum_state = std::make_unique<CPUGeneratorImplState>();
+ auto rng_data = this->engine_.data();
+ accum_state->legacy_pod.the_initial_seed = rng_data.seed_;
+ accum_state->legacy_pod.left = rng_data.left_;
+ accum_state->legacy_pod.seeded = rng_data.seeded_;
+ accum_state->legacy_pod.next = rng_data.next_;
+ std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state));
+ accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy
+ accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy
+ accum_state->legacy_pod.normal_is_valid = false;
+ accum_state->legacy_pod.normal_y = 0.0;
+ accum_state->next_float_normal_sample = 0.0f;
+ accum_state->is_next_float_normal_sample_valid = false;
+ if (this->next_double_normal_sample_) {
+ accum_state->legacy_pod.normal_is_valid = true;
+ accum_state->legacy_pod.normal_y = *(this->next_double_normal_sample_);
+ }
+ if (this->next_float_normal_sample_) {
+ accum_state->is_next_float_normal_sample_valid = true;
+ accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
+ }
+
+ memcpy(rng_state, accum_state.get(), size);
+ return state_tensor.getIntrusivePtr();
+}
+
/**
* Gets the DeviceType of CPUGeneratorImpl.
* Used for type checking during run time.
diff --git a/aten/src/ATen/CPUGeneratorImpl.h b/aten/src/ATen/CPUGeneratorImpl.h
index eceb338966fd..f8b43a04c73c 100644
--- a/aten/src/ATen/CPUGeneratorImpl.h
+++ b/aten/src/ATen/CPUGeneratorImpl.h
@@ -17,6 +17,8 @@ struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
+ void set_state(const c10::TensorImpl& new_state) override;
+ c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
static DeviceType device_type();
uint32_t random();
uint64_t random64();
diff --git a/aten/src/ATen/CUDAGeneratorImpl.h b/aten/src/ATen/CUDAGeneratorImpl.h
index 9a9febd01f8e..1179a049aa08 100644
--- a/aten/src/ATen/CUDAGeneratorImpl.h
+++ b/aten/src/ATen/CUDAGeneratorImpl.h
@@ -129,8 +129,10 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
+ void set_state(const c10::TensorImpl& new_state) override;
+ c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void set_philox_offset_per_thread(uint64_t offset);
- uint64_t philox_offset_per_thread();
+ uint64_t philox_offset_per_thread() const;
void capture_prologue(int64_t* offset_extragraph);
uint64_t capture_epilogue();
PhiloxCudaState philox_cuda_state(uint64_t increment);
diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 41252609953f..341e20cab1f3 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -10,6 +10,9 @@
#include
#include
+#ifdef XPLAT_MOBILE_BUILD
+#include
+#else
namespace at {
/**
* The method should_include_kernel_dtype() returns true/false
@@ -25,6 +28,7 @@ inline constexpr bool should_include_kernel_dtype(
return true;
}
}
+#endif
/**
* In the Facebook internal build (using BUCK), this macro is enabled by
@@ -93,26 +97,6 @@ inline constexpr bool should_include_kernel_dtype(
return __VA_ARGS__(); \
}
-// This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and
-// should be removed once the bfloat16 bringup is complete on other platforms.
-// This is supposed to be used as a wrapper around the lambda function passed to
-// the dispatch macro and will conditionally dispatch ops with bfloat16 type
-// only on ROCm.
-#if !defined(__HIP_PLATFORM_HCC__)
-#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) \
- if (std::is_same<SCALARTYPE, at::BFloat16>::value) { \
- AT_ERROR( \
- #NAME, \
- " not implemented for '", \
- toString(at::ScalarType::BFloat16), \
- "'"); \
- } else { \
- return __VA_ARGS__(); \
- }
-#else
-#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) return __VA_ARGS__()
-#endif
-
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp
index 07fc4e279557..261f6cdd46b5 100644
--- a/aten/src/ATen/ParallelOpenMP.cpp
+++ b/aten/src/ATen/ParallelOpenMP.cpp
@@ -1,4 +1,5 @@
#include
+#include
#if AT_PARALLEL_OPENMP
#include
diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h
index 3890662123a2..f6c3bbbe09cc 100644
--- a/aten/src/ATen/TensorIndexing.h
+++ b/aten/src/ATen/TensorIndexing.h
@@ -10,6 +10,8 @@
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include
+#include <ATen/core/List.h>
+
namespace at {
namespace indexing {
@@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector&
(*dim_ptr)++;
};
-static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
-  std::vector<Tensor> converted_inds(indices.size());
+static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
+ c10::List<c10::optional<Tensor>> converted_inds;
+ converted_inds.reserve(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
const auto &ind = indices[i];
if (ind.defined()) {
- converted_inds[i] = ind.to(ind.options().device(self.device()));
+ converted_inds.push_back(ind.to(ind.options().device(self.device())));
} else {
- converted_inds[i] = std::move(indices[i]);
+ converted_inds.push_back(std::move(indices[i]));
}
}
return converted_inds;
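
The switch to c10::List<c10::optional<Tensor>> matches the new Tensor?[] schema of aten::index and friends: an absent index travels as c10::nullopt rather than an undefined Tensor. A hedged caller-side sketch, assuming the post-change at::index overload and a 2-D self (take_rows is a hypothetical helper):

    #include <ATen/ATen.h>

    at::Tensor take_rows(const at::Tensor& self, const at::Tensor& row_idx) {
      // One optional index per dimension; c10::nullopt means "keep this dim".
      c10::List<c10::optional<at::Tensor>> indices;
      indices.reserve(2);
      indices.push_back(c10::optional<at::Tensor>(row_idx)); // advanced index on dim 0
      indices.push_back(c10::nullopt);                       // full slice on dim 1
      return at::index(self, indices);
    }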
diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h
index 5063beeb08b0..8fa085245459 100644
--- a/aten/src/ATen/VmapTransforms.h
+++ b/aten/src/ATen/VmapTransforms.h
@@ -96,8 +96,17 @@ struct VmapPhysicalToLogicalMap;
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
-// bit of `levels` specifies the minimum number of nested vmaps we are in at
+// bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
+// For example, given:
+// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
+//
+// The rightmost set bit of `levels` is at index 3, indicating that we are
+// inside at most 3 nested vmaps at this point.
+// bitset: 010100
+// ^
+// |
+// levels: 012345
struct TORCH_API VmapPhysicalView {
VmapPhysicalView(Tensor&& tensor, std::bitset levels)
: levels_(levels), tensor_(tensor) {
diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 8c82f965ef0f..9a2f34257c57 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -239,13 +239,9 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction::type::call);
-#define KERNEL_UNBOXED_ONLY(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
- m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
- &WrapFunction::type::call);
-
// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
-#define KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
- m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
+ m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction::type::call);
/*****************************************
@@ -367,20 +363,20 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t), fp32)
KERNEL(ADD_NS(dist), "dist", Tensor (const Tensor &, const Tensor &, Scalar), fp32)
KERNEL(ADD_NS(pdist), "pdist", Tensor (const Tensor &, double), fp32)
- KERNEL_UNBOXED_ONLY(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
+ KERNEL(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
KERNEL(ADD_NS(renorm), "renorm", Tensor (const Tensor &, Scalar, int64_t, Scalar), fp32)
// fp32_set_opt_dtype
KERNEL(ADD_NS(prod), "prod", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(prod), "prod.dim_int", Tensor (const Tensor &, int64_t, bool, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(softmax), "softmax.int", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(log_softmax), "log_softmax.int", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumprod), "cumprod", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumsum), "cumsum", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
// when autocasting.
// KERNEL(ADD_NS(norm), "norm.ScalarOpt_dtype", Tensor (const Tensor &, c10::optional, ScalarType), fp32_set_opt_dtype)
@@ -388,25 +384,25 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
// KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype)
// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
// norm does not implicitly promote, but be aware when adding new ops to this policy.
- KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional, ScalarType), fp32_append_dtype)
- KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional, IntArrayRef, bool), Tensor (const Tensor &, c10::optional, IntArrayRef, bool, ScalarType), fp32_append_dtype)
- KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional, DimnameList, bool), Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_append_dtype)
+ KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional, ScalarType), fp32_append_dtype)
+ KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional, IntArrayRef, bool), Tensor (const Tensor &, c10::optional, IntArrayRef, bool, ScalarType), fp32_append_dtype)
+ KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional, DimnameList, bool), Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_append_dtype)
// promote
KERNEL(ADD_NS(addcdiv), "addcdiv", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote)
KERNEL(ADD_NS(addcmul), "addcmul", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote)
KERNEL(ADD_NS(atan2), "atan2", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(bilinear), "bilinear", Tensor (const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&), promote)
KERNEL(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
- KERNEL_UNBOXED_ONLY(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote)
+ KERNEL(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote)
KERNEL(ADD_NS(_cat), "_cat", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
- KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
+ KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)
diff --git a/aten/src/ATen/core/Generator.cpp b/aten/src/ATen/core/Generator.cpp
new file mode 100644
index 000000000000..800f8c7c88ec
--- /dev/null
+++ b/aten/src/ATen/core/Generator.cpp
@@ -0,0 +1,16 @@
+#include <ATen/core/Generator.h>
+#include <ATen/core/Tensor.h>
+#include <c10/util/Exception.h>
+
+namespace at {
+
+void Generator::set_state(const at::Tensor& new_state) {
+ TORCH_CHECK(new_state.defined(), "Undefined tensor is not allowed");
+ this->impl_->set_state(*new_state.unsafeGetTensorImpl());
+}
+
+at::Tensor Generator::get_state() const {
+ return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
+}
+
+} // namespace at
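
These two methods are what torch.get_rng_state() and torch.set_rng_state() bottom out in. A minimal round-trip sketch, assuming the default CPU generator accessor from ATen/CPUGeneratorImpl.h (locking the generator's mutex is omitted for brevity):

    #include <ATen/ATen.h>
    #include <ATen/CPUGeneratorImpl.h>

    void rng_checkpoint() {
      at::Generator gen = at::detail::getDefaultCPUGenerator();
      at::Tensor snapshot = gen.get_state();  // CPU byte tensor
      at::randn({2, 2});                      // advances the engine
      gen.set_state(snapshot);                // restores it exactly
    }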
diff --git a/aten/src/ATen/core/Generator.h b/aten/src/ATen/core/Generator.h
index de3f6e46f8f2..b5bbb2fe3c74 100644
--- a/aten/src/ATen/core/Generator.h
+++ b/aten/src/ATen/core/Generator.h
@@ -56,6 +56,8 @@
namespace at {
+class Tensor;
+
struct TORCH_API Generator {
Generator() {}
@@ -96,6 +98,12 @@ struct TORCH_API Generator {
uint64_t seed() { return impl_->seed(); }
+ // Implementation not inlined to prevent cycle reference between
+ // `ATen/core/Generator.h` and `ATen/core/Tensor.h`
+ void set_state(const at::Tensor& new_state);
+
+ at::Tensor get_state() const;
+
std::mutex& mutex() {
return impl_->mutex_;
}
@@ -130,4 +138,24 @@ Generator make_generator(Args&&... args) {
return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
}
+namespace detail {
+
+/**
+ * Helper function for checking the validity of new random generator
+ * state. Right now following conditions are checked:
+ *
+ * - The new state tensor must be a torch.ByteTensor
+ * - Data of the new state tensor must be contiguous
+ */
+static inline void check_rng_state(const c10::TensorImpl& new_state) {
+ TORCH_CHECK_TYPE(
+ new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
+ "RNG state must be a torch.ByteTensor"
+ );
+
+ TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
+}
+
+} // namespace detail
+
} // namespace at
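
A small usage sketch of the validation helper above (illustrative only); it throws unless the state is a contiguous CPU byte tensor:

    #include <ATen/ATen.h>

    void validate_state(const at::Tensor& state) {
      // check_rng_state requires a strided, contiguous torch.ByteTensor on CPU.
      at::detail::check_rng_state(*state.unsafeGetTensorImpl());
    }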
diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h
index 40f733784fe5..f911722c51e1 100644
--- a/aten/src/ATen/core/List.h
+++ b/aten/src/ATen/core/List.h
@@ -243,7 +243,7 @@ class List final {
* Example:
* List a({2, 3, 4});
*/
- explicit List(std::initializer_list<T> initial_values);
+ List(std::initializer_list<T> initial_values);
explicit List(ArrayRef<T> initial_values);
/**
diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h
index 3cbd7a310275..ab3ddae55770 100644
--- a/aten/src/ATen/core/List_inl.h
+++ b/aten/src/ATen/core/List_inl.h
@@ -1,7 +1,7 @@
#pragma once
+#include
#include
-#include
namespace c10 {
@@ -50,7 +50,17 @@ List::List(TypePtr elementType)
namespace impl {
template<class T>
List<T> toTypedList(impl::GenericList list) {
-  TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
+ // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
+ // because upcasting would allow people to add types into the new list that would break the old list.
+ // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
+ // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
+ // without having to copy it. This is also used to provide backwards compatibility with some old models
+ // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
+ // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
+ // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
+ TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
+   || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
+   , "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
return List<T>(std::move(list.impl_));
}
@@ -312,3 +322,5 @@ void List::unsafeSetElementType(TypePtr t) {
impl_->elementType = std::move(t);
}
}
+
+#include
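
A sketch of the upcast the relaxed check enables, written against the impl::toList/impl::toTypedList helpers in this header (an illustration, not code from the patch). Assuming the caller holds the only reference, a List<Tensor> can be re-typed as List<optional<Tensor>> without copying, since Tensor is a subtype of Tensor?:

    #include <ATen/ATen.h>
    #include <ATen/core/List.h>

    c10::List<c10::optional<at::Tensor>> upcast(c10::List<at::Tensor>&& tensors) {
      // Erase the static element type; the moved-from impl keeps use_count() == 1.
      c10::impl::GenericList generic = c10::impl::toList(std::move(tensors));
      // With sole ownership, toTypedList may rewrite the element type in place.
      return c10::impl::toTypedList<c10::optional<at::Tensor>>(std::move(generic));
    }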
diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h
index b49d94bba1c8..d33f3d575177 100644
--- a/aten/src/ATen/core/Variadic.h
+++ b/aten/src/ATen/core/Variadic.h
@@ -6,6 +6,7 @@
#include
#include
+#include <ATen/core/List.h>
namespace at {
@@ -56,6 +57,15 @@ struct IterArgs {
}
}
+ template <typename T>
+ void operator()(const torch::List<T>& args) {
+ for (const auto& arg : args) {
+ self()(arg);
+ if (self().short_circuit())
+ return;
+ }
+ }
+
// NB: we need to specify std::vector manually as C++ won't
// do an implicit conversion to make a template deduction go through.
template <typename T>
diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp
index f84352ebee1f..58c35557018c 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.cpp
+++ b/aten/src/ATen/core/boxing/KernelFunction.cpp
@@ -57,25 +57,4 @@ bool KernelFunction::_equalsBoxedAndUnboxed(const KernelFunction& other) const {
unboxed_kernel_func_ == other.unboxed_kernel_func_;
}
-void KernelFunction::checkBoxedKernel(const OperatorHandle& opHandle) const {
- if (C10_UNLIKELY(boxed_kernel_func_ == nullptr)) {
- if (unboxed_kernel_func_ == nullptr) {
- TORCH_INTERNAL_ASSERT(
- false,
- "Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction.",
- " opname: ",
- opHandle.operator_name(),
- " If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`.");
- } else {
- // TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this case should be impossible.
- TORCH_INTERNAL_ASSERT(
- false,
- "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call().",
- " opname: ",
- opHandle.operator_name(),
- " If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`.");
- }
- }
-}
-
} // namespace c10
diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index 6817907b12b1..ddbbd912777a 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -123,26 +123,6 @@ class TORCH_API KernelFunction final {
template<bool AllowLegacyTypes = false, class KernelFunctor>
static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);
- /**
- * Create a KernelFunction from an unboxed functor and prevent creation of an
- * unboxing-wrapper. This means that you cannot call this KernelFunction
- * using KernelFunction::callBoxed()
- *
- * This is necessary because our unboxing wrappers don't work for all types
- * yet, so if you want to use one of these types as function arguments,
- * you need to use makeFromUnboxedOnlyFunctor.
- *
- * Example:
- *
- * > class MyFunctor final {
- * > public:
- * > Tensor operator()(Tensor a, Tensor b) {...}
- * > };
- * > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::make_unique<MyFunctor>());
- */
- template<class KernelFunctor>
- static KernelFunction makeFromUnboxedOnlyFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);
-
/**
* Create a KernelFunction from an unboxed function.
* This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
@@ -158,23 +138,6 @@ class TORCH_API KernelFunction final {
template<class FuncPtr, bool AllowLegacyTypes = false>
static KernelFunction makeFromUnboxedFunction(FuncPtr);
- /**
- * Create a KernelFunction from an unboxed function and prevent creation of an
- * unboxing-wrapper. This means that you cannot call this KernelFunction
- * using KernelFunction::callBoxed()
- *
- * This is necessary because our unboxing wrappers don't work for all types
- * yet, so if you want to use one of these types as function arguments,
- * you need to use makeFromUnboxedOnlyFunctor.
- *
- * Example:
- *
- * > Tensor unboxed_func(Tensor a, Tensor b) {...}
- * > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction();
- */
- template<class FuncPtr>
- static KernelFunction makeFromUnboxedOnlyFunction(FuncPtr);
-
/**
* Create a KernelFunction from an unboxed function.
* KernelFunction::makeFromUnboxedFunction is usually a better choice than
@@ -189,9 +152,6 @@ class TORCH_API KernelFunction final {
template<bool AllowLegacyTypes = false, class FuncType>
static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
- template<class FuncType>
- static KernelFunction makeFromUnboxedOnlyRuntimeFunction(FuncType* func);
-
static KernelFunction makeFallthrough();
static KernelFunction makeAmbiguousAutogradOther();
static KernelFunction makeNamedNotSupported();
@@ -213,12 +173,6 @@ class TORCH_API KernelFunction final {
// For testing internal invariants only
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
- // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic. This can be done once https://github.com/pytorch/pytorch/issues/32366 is fixed.
- void setManuallyBoxedKernel_(InternalBoxedKernelFunction* func);
-
private:
explicit KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func);
@@ -226,8 +180,6 @@ class TORCH_API KernelFunction final {
template<class KernelFunctor>
static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, Stack* stack);
- void checkBoxedKernel(const OperatorHandle& opHandle) const;
-
OperatorKernel* getFunctor_() const;
std::shared_ptr<OperatorKernel> functor_;
diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h
index 82a65fa27ffb..b248e54a6f94 100644
--- a/aten/src/ATen/core/boxing/KernelFunction_impl.h
+++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h
@@ -23,8 +23,7 @@ inline void KernelFunction::make_boxed_function(OperatorKernel*, const OperatorH
}
inline bool KernelFunction::isValid() const {
- // TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this should only check boxed_kernel_func_.
- return boxed_kernel_func_ != nullptr || unboxed_kernel_func_ != nullptr;
+ return boxed_kernel_func_ != nullptr;
}
inline bool KernelFunction::isFallthrough() const {
@@ -32,7 +31,10 @@ inline bool KernelFunction::isFallthrough() const {
}
inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, Stack* stack) const {
- checkBoxedKernel(opHandle);
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ boxed_kernel_func_ != nullptr,
+ "Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction."
+ );
(*boxed_kernel_func_)(functor_.get(), opHandle, stack);
}
@@ -111,21 +113,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr
-inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr<OperatorKernel> kernelFunctor) {
- // TODO We want to get rid of kernels that have only an unboxed function pointer.
- // All kernels should have a boxed pointer.
-
- static_assert(guts::is_functor<KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor but the argument is not a functor.");
- static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
-
- return KernelFunction(
- std::move(kernelFunctor),
- nullptr, // Don't create a boxed kernel for this
- reinterpret_cast(&impl::wrap_kernel_functor_unboxed::call)
- );
-}
-
template<class FuncPtr, bool AllowLegacyTypes>
inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) {
static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
@@ -144,26 +131,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr)
#endif
}
-template<class FuncPtr>
-inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunction(FuncPtr func_ptr) {
- // TODO We want to get rid of kernels that have only an unboxed function pointer.
- // All kernels should have a boxed pointer.
- static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
- static_assert(!std::is_same<typename FuncPtr::FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
- static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
-
-#if !defined(C10_MOBILE)
- return makeFromUnboxedOnlyFunctor<typename impl::WrapFunctionIntoFunctor<FuncPtr>::type> (
- guts::make_unique_base<OperatorKernel, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>()
- );
-#else
- // On mobile, we rather want to optimize for binary size than for performance,
- // so let's not inline the kernel into the wrapper but use makeFromUnboxedOnlyRuntimeFunction
- // instead.
- return makeFromUnboxedOnlyRuntimeFunction(func_ptr.func_ptr());
-#endif
-}
-
template<bool AllowLegacyTypes, class FuncType>
inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) {
static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
@@ -175,17 +142,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* f
);
}
-template<class FuncType>
-inline KernelFunction KernelFunction::makeFromUnboxedOnlyRuntimeFunction(FuncType* func) {
- static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
- static_assert(!std::is_same<FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
- TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
-
- return makeFromUnboxedOnlyFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
- guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func)
- );
-}
-
template<class Lambda>
inline std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
@@ -212,14 +168,4 @@ inline std::enable_if_t>::value,
);
}
-inline void KernelFunction::setManuallyBoxedKernel_(InternalBoxedKernelFunction* func) {
- if (boxed_kernel_func_ == &fallthrough_kernel) {
- // special case no-op
- return;
- }
- TORCH_INTERNAL_ASSERT(boxed_kernel_func_ == nullptr, "Tried to set a manually boxed kernel for a kernel that already has a boxed kernel set.");
- TORCH_INTERNAL_ASSERT(unboxed_kernel_func_ != nullptr, "Tried to set a manually boxed kernel for an invalid KernelFunction.");
- boxed_kernel_func_ = func;
-}
-
}
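
With the unboxed-only factories gone, every remaining constructor yields a kernel that is also boxed-callable, which is exactly what the tightened isValid() asserts. A hedged sketch (my_relu is a hypothetical kernel):

    #include <ATen/ATen.h>
    #include <ATen/core/boxing/KernelFunction.h>

    at::Tensor my_relu(const at::Tensor& t) {
      return t.clamp_min(0);
    }

    void check() {
      auto k = c10::KernelFunction::makeFromUnboxedFunction(TORCH_FN(my_relu));
      // isValid() now means "has a boxed kernel"; both calling conventions work.
      TORCH_INTERNAL_ASSERT(k.isValid());
    }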
diff --git a/aten/src/ATen/core/boxing/KernelFunction_test.cpp b/aten/src/ATen/core/boxing/KernelFunction_test.cpp
index 8ba50db14a2b..e17efab10ba5 100644
--- a/aten/src/ATen/core/boxing/KernelFunction_test.cpp
+++ b/aten/src/ATen/core/boxing/KernelFunction_test.cpp
@@ -544,26 +544,6 @@ TEST(KernelFunctionTest, givenUnboxedFunctor_withoutReturn_whenCallingUnboxed_th
kernels::expectUnboxedCallingWithoutReturnWorks(func);
}
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withoutReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectUnboxedCallingWithReturnWorks(func);
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withoutReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectUnboxedCallingWithoutReturnWorks(func);
-}
-
TEST(KernelFunctionTest, givenUnboxedFunction_withReturn_whenCallingBoxed_thenWorks) {
KernelFunction func = KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernels::unboxed_function_with_return));
kernels::expectBoxedCallingWithReturnWorks(func);
@@ -584,26 +564,6 @@ TEST(KernelFunctionTest, givenUnboxedFunction_withoutReturn_whenCallingUnboxed_t
kernels::expectUnboxedCallingWithoutReturnWorks(func);
}
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_with_return));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withoutReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_without_return));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_with_return));
- kernels::expectUnboxedCallingWithReturnWorks(func);
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withoutReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_without_return));
- kernels::expectUnboxedCallingWithoutReturnWorks(func);
-}
-
TEST(KernelFunctionTest, givenUnboxedRuntimeFunction_withReturn_whenCallingBoxed_thenWorks) {
KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&kernels::unboxed_function_with_return);
kernels::expectBoxedCallingWithReturnWorks(func);
diff --git a/aten/src/ATen/core/builtin_function.h b/aten/src/ATen/core/builtin_function.h
index 8bfb4f7e9d16..adeaa1039638 100644
--- a/aten/src/ATen/core/builtin_function.h
+++ b/aten/src/ATen/core/builtin_function.h
@@ -101,8 +101,17 @@ struct BuiltinOpFunction : public Function {
}
std::string pretty_print_schema() const override {
+ #ifdef __NVCC__
+ // Disable the "statement is unreachable" warning
+ #pragma diag_suppress code_is_unreachable
+ #endif
+
TORCH_INTERNAL_ASSERT(false);
return "";
+
+ #ifdef __NVCC__
+ #pragma diag_default code_is_unreachable
+ #endif
}
Function& setSchema(c10::FunctionSchema schema) override {
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp
index 5e3e91afbb45..270cffaf6d1f 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.cpp
+++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -295,12 +295,6 @@ void Dispatcher::checkInvariants() const {
}
}
-void Dispatcher::setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func) {
- std::lock_guard<std::mutex> lock(mutex_);
- op.operatorIterator_->op.setManuallyBoxedKernel_(*this, func);
- // NB: Do not need to set manually boxed kernel for backend fallbacks
-}
-
std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorHandle> {
std::vector<OperatorHandle> opsWithDanglingImpls;
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 60f9f9bd0579..d83653f75363 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -182,12 +182,6 @@ class TORCH_API Dispatcher final {
*/
RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
- // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete setBoxedKernelFor_ once all operators work with the templated boxing logic
- void setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func);
-
// ------------------------------------------------------------------------
//
// Listeners on registrations
@@ -310,7 +304,9 @@ class TORCH_API OperatorHandle {
// smuggle in a kernel that is typed incorrectly). For everything
// in core library this won't happen, because all the static registrations
// will be done by the time a typed() handle is acquired.
+#if !defined C10_MOBILE
operatorIterator_->op.assertSignatureIsCorrect();
+#endif
return TypedOperatorHandle<FuncType>(operatorIterator_);
}
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index f0d7bc6968ed..7c3698beeb06 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -21,7 +21,6 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
, schema_()
, dispatchTable_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
-, manuallyBoxedKernel_()
, kernels_()
, cpp_signature_()
, is_observed_(ObservedOperators::isObserved(name_))
@@ -122,10 +121,6 @@ std::list::iterator OperatorEntry::registerKernel(
);
}
- if (manuallyBoxedKernel_.has_value()) {
- kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_);
- }
-
k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
std::list<AnnotatedKernel>::iterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
@@ -331,19 +326,6 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher)
}
}
-void OperatorEntry::setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func) {
- TORCH_INTERNAL_ASSERT(!manuallyBoxedKernel_);
- manuallyBoxedKernel_ = func;
-
- for (auto& kv : kernels_) {
- for (auto& k : kv.second) {
- k.kernel.setManuallyBoxedKernel_(func);
- }
- }
- // Refresh entries in dispatchTable_
- updateDispatchTableFull_(dispatcher);
-}
-
void OperatorEntry::checkInvariants() const {
if (schema_) {
TORCH_INTERNAL_ASSERT(schema_->schema.operator_name() == name_, dumpState());
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h
index 5098fd0d8c28..44b8fac5661e 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.h
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -148,12 +148,6 @@ class TORCH_API OperatorEntry final {
const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; }
- // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic
- void setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func);
-
// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
template<class FuncType>
void assertSignatureIsCorrect() {
@@ -189,12 +183,6 @@ class TORCH_API OperatorEntry final {
std::array(DispatchKey::NumDispatchKeys)> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;
- // This manuallyBoxedKernel_ member is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete manuallyBoxedKernel_ once all operators work with the templated boxing logic
- c10::optional<KernelFunction::InternalBoxedKernelFunction*> manuallyBoxedKernel_;
-
// kernels_ stores all registered kernels for the corresponding dispatch key
// and catchAllKernels_ stores the catch-all kernels.
// If an operator library gets loaded that overwrites an already existing kernel,
diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h
index 720d274ec5b2..624ded76ffda 100644
--- a/aten/src/ATen/core/function_schema.h
+++ b/aten/src/ATen/core/function_schema.h
@@ -107,7 +107,7 @@ struct Argument {
c10::optional<int32_t> N_;
c10::optional<IValue> default_value_;
- // is this only specifyable as a keyword argument?
+ // is this only specifiable as a keyword argument?
bool kwarg_only_;
c10::optional<AliasInfo> alias_info_;
};
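// [Editor's sketch, not part of the patch] kwarg_only_ corresponds to the
// arguments declared after the `*` marker in a schema string; the schema below
// is hypothetical:
//
//   my_ns::clamp_min(Tensor self, *, Scalar min) -> Tensor
//
// Here `min` can only be passed by keyword, so its Argument carries
// kwarg_only_ == true, while `self` does not.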
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 8065300f0b32..f99dc3c07058 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -17,6 +17,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
+ _(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@@ -284,6 +285,9 @@ namespace c10 {
_(aten, zero_) \
_(aten, fill_) \
_(aten, masked_fill_) \
+ _(cuda, _set_device) \
+ _(cuda, set_stream) \
+ _(cuda, _current_device) \
_(aten, swapaxes) \
_(aten, swapaxes_) \
_(aten, swapdims) \
@@ -383,6 +387,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
+ _(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@@ -453,6 +458,7 @@ struct TORCH_API Symbol {
// (and if it's not, you should add it to the built-ins list above.)
static Symbol attr(const std::string & s);
static Symbol aten(const std::string & s);
+ static Symbol cuda(const std::string & s);
static Symbol onnx(const std::string & s);
static Symbol prim(const std::string & s);
static Symbol user(const std::string & s);
@@ -463,6 +469,7 @@ struct TORCH_API Symbol {
bool is_attr() const;
bool is_aten() const;
+ bool is_cuda() const;
bool is_prim() const;
bool is_onnx() const;
bool is_user() const;
@@ -523,6 +530,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL)
inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); }
+inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); }
inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
@@ -531,6 +539,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
+inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }
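// [Editor's sketch, not part of the patch] The new `cuda` namespace composes
// with the existing Symbol helpers exactly like `aten` and `prim`:
//
//   c10::Symbol s = c10::Symbol::cuda("_set_device");
//   // equivalent to c10::Symbol::fromQualString("cuda::_set_device")
//   bool a = s.is_cuda(); // true
//   bool b = s.is_aten(); // false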
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index c05e7313fa63..1223577c59c6 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -125,7 +125,7 @@ TypePtr IValue::type() const {
void IValue::visit(const std::function<bool(const IValue&)>& visitor) const {
if (visitor(*this)) {
- // Short cut.
+ // Shortcut
return;
}
switch (this->tag) {
@@ -265,7 +265,7 @@ bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) {
TORCH_INTERNAL_ASSERT(lhs.is_intrusive_ptr);
TORCH_INTERNAL_ASSERT(rhs.is_intrusive_ptr);
return lhs.tag == rhs.tag &&
- lhs.payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
+ lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
}
IValue IValue::equals(const IValue& rhs) const {
@@ -325,17 +325,17 @@ size_t IValue::hash(const IValue& v) {
case Tag::None:
return 0;
case Tag::Bool:
- return c10::get_hash(v.payload.as_bool);
+ return c10::get_hash(v.payload.u.as_bool);
case Tag::Double:
- return c10::get_hash(v.payload.as_double);
+ return c10::get_hash(v.payload.u.as_double);
case Tag::Tensor:
// Tensor __hash__ is equivalent to `id()`, so take the pointer value of
// the tensor to emulate it
- return c10::get_hash(v.payload.as_int);
+ return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
case Tag::Storage:
- return c10::get_hash(v.payload.as_int);
+ return c10::get_hash(v.payload.u.as_int);
case Tag::Int:
- return c10::get_hash(v.payload.as_int);
+ return c10::get_hash(v.payload.u.as_int);
case Tag::String:
return c10::get_hash(v.toStringRef());
case Tag::Tuple:
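// [Editor's sketch, not part of the patch] With the Tensor now stored as
// payload.as_tensor, the hash keys on the TensorImpl pointer rather than a
// reinterpreted integer, preserving the id()-like identity semantics:
//
//   at::Tensor t = at::ones({2});
//   IValue a(t);         // shares t's TensorImpl
//   IValue b(t.clone()); // clone() allocates a distinct TensorImpl
//   // IValue::hash(a) == IValue::hash(IValue(t)), since both see the same impl
//   // IValue::hash(a) != IValue::hash(b) in general, since the impls differ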
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 4a7e15c4008b..ca68a8df46e1 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -131,10 +131,15 @@ struct Capsule {
// they are marked `@private`, which hides them on the doxygen documentation for
// this page.
-/// IValue (Interpreter Value) is a tagged union over the types supported by the
-/// TorchScript interpreter. IValues contain their values as an
-/// `IValue::Payload`, which holds primitive types (`int64_t`, `bool`, `double`,
-/// `Device`), as values and all other types as a `c10::intrusive_ptr`.
+/// IValue (Interpreter Value) is a tagged union over the types
+/// supported by the TorchScript interpreter. IValues contain their
+/// values as an `IValue::Payload`, which holds primitive types
+/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values,
+/// and all other types as a `c10::intrusive_ptr`. In order to
+/// optimize performance of the destructor and related operations by
+/// making the `Tensor` and `c10::intrusive_ptr` paths generate the
+/// same code, we represent a null `c10::intrusive_ptr` as
+/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`.
///
/// IValues are used as inputs to and outputs from the TorchScript interpreter.
/// To retrieve the value contained within an IValue, use the `.toX()` methods,
@@ -160,27 +165,35 @@ struct Capsule {
struct TORCH_API IValue final {
IValue(const IValue& rhs)
: IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) {
- if (is_intrusive_ptr) {
- c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr);
+ if (is_intrusive_ptr && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+ c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
}
}
- IValue(IValue&& rhs) noexcept : IValue() {
- swap(rhs);
+
+ IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) {
+ moveFrom(std::move(rhs));
}
+
/// @private [doxygen private]
~IValue() {
- if (is_intrusive_ptr) {
- c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
- }
+ destroy();
}
- IValue& operator=(IValue&& rhs) & noexcept {
- IValue(std::move(rhs)).swap(*this); // this also sets rhs to None
+
+ C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
+ if (&rhs == this) {
+ return *this;
+ }
+
+ destroy();
+ moveFrom(std::move(rhs));
return *this;
}
+
IValue& operator=(IValue const& rhs) & {
IValue(rhs).swap(*this);
return *this;
}
+
void dump() const;
/**
@@ -260,6 +273,13 @@ struct TORCH_API IValue final {
return false;
}
+ // Tensors should be compared based on internal storage
+ if (this->isTensor()) {
+ const auto& thisTensor = this->toTensor();
+ const auto& rhsTensor = rhs.toTensor();
+ return thisTensor.is_alias_of(rhsTensor);
+ }
+
if (!this->is_intrusive_ptr) {
// Primitive types don't alias anything
return false;
@@ -267,29 +287,49 @@ struct TORCH_API IValue final {
AT_ASSERT(rhs.is_intrusive_ptr);
- // Tensors should be compared based on internal storage
- if (this->isTensor()) {
- const auto thisTensor = this->toTensor();
- const auto rhsTensor = rhs.toTensor();
- return thisTensor.is_alias_of(rhsTensor);
- }
-
// Other types can be compared by their ptr value
- return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
+ return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
}
/// @private [doxygen private]
size_t use_count() const noexcept {
+ if (isTensor()) {
+ return payload.as_tensor.use_count();
+ }
+
if (!is_intrusive_ptr) {
return 1;
}
- return c10::raw::intrusive_ptr::use_count(payload.as_intrusive_ptr);
+ if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
+ return 0;
+ }
+ return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr);
}
/// @private [doxygen private]
void swap(IValue& rhs) noexcept {
- std::swap(payload, rhs.payload);
+ if (isTensor() && rhs.isTensor()) {
+ std::swap(payload.as_tensor, rhs.payload.as_tensor);
+ } else if (isTensor()) {
+ at::Tensor t = std::move(payload.as_tensor);
+ // As far as I can tell, omitting the usual explicit destructor call
+ // is not UB in and of itself, and it's a slight perf win. The
+ // destructor is a no-op, because the moved-from Tensor is
+ // effectively an intrusive_ptr in the null state, so we don't need
+ // the behavior for correctness reasons either. Leaving this
+ // explanatory comment, including commented-out destructor call, to
+ // make this abundantly clear.
+ //
+ // payload.as_tensor.~Tensor();
+ payload.u = rhs.payload.u;
+ new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
+ } else if (rhs.isTensor()) {
+ rhs.swap(*this);
+ return;
+ } else {
+ std::swap(payload.u, rhs.payload.u);
+ }
std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
std::swap(tag, rhs.tag);
}
@@ -298,21 +338,17 @@ struct TORCH_API IValue final {
// While some of these accessors could be generated through templates,
// we prefer to write them manually for clarity
- IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) {
- // Note: the undefined tensor is not refcounted, so while it
- // is tagged as a tensor, is_intrusive_ptr is set to false.
- // This is not an optional optimization: our incref call
- // *will not* do the right thing when called on an
- // undefined tensor.
- payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl();
+ IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) {
+ new (&payload.as_tensor) at::Tensor(std::move(t));
}
bool isTensor() const {
return Tag::Tensor == tag;
}
at::Tensor toTensor() &&;
- at::Tensor toTensor() const&;
+ at::Tensor& toTensor() &;
+ const at::Tensor& toTensor() const&;
at::TensorImpl* unsafeToTensorImpl() const {
- return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
+ return payload.as_tensor.unsafeGetTensorImpl();
}
IValue(at::Storage s) : tag(Tag::Storage), is_intrusive_ptr(static_cast<bool>(s)) {
@@ -321,7 +357,7 @@ struct TORCH_API IValue final {
// This is not an optional optimization: our incref call
// *will not* do the right thing when called on an
// undefined tensor.
- payload.as_intrusive_ptr = s.unsafeReleaseStorageImpl();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(s.unsafeReleaseStorageImpl());
}
bool isStorage() const {
return Tag::Storage == tag;
@@ -341,7 +377,7 @@ struct TORCH_API IValue final {
: tag(Tag::Blob), is_intrusive_ptr(true) {
// TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
// and store it as a Tensor instead.
- payload.as_intrusive_ptr = blob.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
}
/// @private [doxygen private]
@@ -397,14 +433,14 @@ struct TORCH_API IValue final {
// Double
IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) {
- payload.as_double = d;
+ payload.u.as_double = d;
}
bool isDouble() const {
return Tag::Double == tag;
}
double toDouble() const {
AT_ASSERT(isDouble());
- return payload.as_double;
+ return payload.u.as_double;
}
// Future
@@ -433,7 +469,7 @@ struct TORCH_API IValue final {
// Int
IValue(int64_t i) : tag(Tag::Int), is_intrusive_ptr(false) {
- payload.as_int = i;
+ payload.u.as_int = i;
}
// allow you to pass literals (3, 4) without ambiguity
@@ -445,7 +481,7 @@ struct TORCH_API IValue final {
int64_t toInt() const {
AT_ASSERT(isInt());
- return payload.as_int;
+ return payload.u.as_int;
}
// Bool
@@ -454,9 +490,9 @@ struct TORCH_API IValue final {
// Initializing entire payload stops valgrind from reporting
// "jump or move depends on uninitialised value" in IValue copy constructor
// See https://github.com/pytorch/pytorch/issues/37117
- payload.as_int = b;
+ payload.u.as_int = b;
#else
- payload.as_bool = b;
+ payload.u.as_bool = b;
#endif
}
bool isBool() const {
@@ -464,7 +500,7 @@ struct TORCH_API IValue final {
}
bool toBool() const {
AT_ASSERT(isBool());
- return payload.as_bool;
+ return payload.u.as_bool;
}
// IntList
@@ -580,7 +616,7 @@ struct TORCH_API IValue final {
c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&;
// None
- IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
+ IValue() : tag(Tag::None), is_intrusive_ptr(false) {}
bool isNone() const {
return Tag::None == tag;
}
@@ -616,21 +652,21 @@ struct TORCH_API IValue final {
// Device
IValue(c10::Device d) : tag(Tag::Device), is_intrusive_ptr(false) {
- payload.as_device.type = d.type();
- payload.as_device.index = d.index();
+ payload.u.as_device.type = d.type();
+ payload.u.as_device.index = d.index();
}
bool isDevice() const {
return Tag::Device == tag;
}
c10::Device toDevice() const {
AT_ASSERT(isDevice());
- return c10::Device(payload.as_device.type, payload.as_device.index);
+ return c10::Device(payload.u.as_device.type, payload.u.as_device.index);
}
//Stream
IValue(c10::Stream stream)
: tag(Tag::Stream), is_intrusive_ptr(false) {
- payload.as_int = stream.pack();
+ payload.u.as_int = stream.pack();
}
c10::Stream toStream() &&;
c10::Stream toStream() const &;
@@ -659,7 +695,7 @@ struct TORCH_API IValue final {
// QScheme
IValue(at::QScheme qscheme) : tag(Tag::Int), is_intrusive_ptr(false) {
- payload.as_int = static_cast<int64_t>(qscheme);
+ payload.u.as_int = static_cast<int64_t>(qscheme);
}
at::QScheme toQScheme() const {
@@ -680,7 +716,7 @@ struct TORCH_API IValue final {
// This is not an optional optimization: our incref call
// *will not* do the right thing when called on an
// undefined generator.
- payload.as_intrusive_ptr = g.unsafeReleaseGeneratorImpl();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl());
}
bool isGenerator() const {
return Tag::Generator == tag;
@@ -749,14 +785,19 @@ struct TORCH_API IValue final {
const IValue& v);
bool isPtrType() const {
- return is_intrusive_ptr;
+ return (isTensor() && payload.as_tensor.defined()) || is_intrusive_ptr;
}
/// @private [doxygen private]
const void* internalToPointer() const {
TORCH_INTERNAL_ASSERT(
isPtrType(), "Can only call internalToPointer() for pointer types");
- return payload.as_intrusive_ptr;
+ if (isTensor()) {
+ return payload.as_tensor.unsafeGetTensorImpl();
+ } else {
+ return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()
+ ? payload.u.as_intrusive_ptr : nullptr;
+ }
}
TypePtr type() const;
@@ -770,7 +811,7 @@ struct TORCH_API IValue final {
}
// If it is not a Tensor, then two mutable IValues alias each other only
// if they are the same pointer.
- return val.payload.as_int;
+ return val.payload.u.as_int;
}
};
@@ -800,6 +841,10 @@ struct TORCH_API IValue final {
IValue deepcopy(HashAliasedIValueMap& memo) const;
private:
+ static c10::intrusive_ptr_target* null_to_undefined_tensor(c10::intrusive_ptr_target* p) {
+ return p ? p : static_cast<c10::intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
+ }
+
static bool ptrEqual(const IValue& lhs, const IValue& rhs);
// NOTE: IValue tags are intentionally private. In the future we may encode
// this value differently (e.g. using NaN boxing), and this would make it more
@@ -822,24 +867,77 @@ struct TORCH_API IValue final {
class NullType = c10::detail::intrusive_target_default_null_type<T>>
c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;
- void clearToNone() {
- payload.as_int = 0;
+ void destroy() {
+ // We carefully construct this call to both 1) avoid UB by using
+ // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable
+ // the compiler to generate the same code for each case. It is
+ // surprisingly difficult to get this right.
+ if (isTensor() || is_intrusive_ptr) {
+ c10::intrusive_ptr_target* p = isTensor() ? payload.as_tensor.unsafeGetTensorImpl() : payload.u.as_intrusive_ptr;
+ c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(p);
+ // No need to make this destructor call!
+ // payload.as_tensor.~Tensor();
+ }
+ }
+
+ C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
+ if (rhs.isTensor()) {
+ new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
+ // As far as I can tell, omitting the usual explicit destructor call
+ // is not UB in and of itself, and it's a slight perf win. The
+ // destructor is a no-op, because the moved-from Tensor is
+ // effectively an intrusive_ptr in the null state, so we don't need
+ // the behavior for correctness reasons either. Leaving this
+ // explanatory comment, including commented-out destructor call, to
+ // make this abundantly clear.
+ //
+ // rhs.payload.as_tensor.~Tensor();
+ } else {
+ payload.u = rhs.payload.u;
+ }
+ tag = rhs.tag;
+ is_intrusive_ptr = rhs.is_intrusive_ptr;
+ rhs.clearToNone();
+ }
+
+ void clearToNone() noexcept {
+ payload.u.as_int = 0;
tag = Tag::None;
is_intrusive_ptr = false;
}
union Payload {
- int64_t as_int;
- double as_double;
- bool as_bool;
- c10::intrusive_ptr_target* as_intrusive_ptr;
- struct {
- DeviceType type;
- DeviceIndex index;
- } as_device;
+ // We use a nested union here so that we can make the copy easy
+ // and efficient in the non-tensor (i.e., trivially copyable)
+ // case. Specifically, we do not have to do a switch-on-tag to
+ // figure out which union member to assign; we can just use
+ // TriviallyCopyablePayload::operator=.
+ union TriviallyCopyablePayload {
+ TriviallyCopyablePayload() : as_int(0) {}
+ int64_t as_int;
+ double as_double;
+ bool as_bool;
+ // Invariant: never nullptr; null state is represented as
+ // c10::UndefinedTensorImpl::singleton() for consistency of
+ // representation with Tensor.
+ c10::intrusive_ptr_target* as_intrusive_ptr;
+ struct {
+ DeviceType type;
+ DeviceIndex index;
+ } as_device;
+ } u;
+ at::Tensor as_tensor;
+ Payload() : u() {}
+ ~Payload() {}
};
- IValue(Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {}
+ IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) {
+ if (isTensor()) {
+ new (&payload.as_tensor) at::Tensor(p.as_tensor);
+ } else {
+ payload.u = p.u;
+ }
+ }
Payload payload;
Tag tag;
@@ -848,29 +946,36 @@ struct TORCH_API IValue final {
};
struct TORCH_API WeakIValue final {
- WeakIValue() : payload{0}, tag(IValue::Tag::None), is_intrusive_ptr(false) {}
+ WeakIValue() : tag(IValue::Tag::None), is_intrusive_ptr(false) {}
WeakIValue(const WeakIValue& rhs)
: payload(rhs.payload),
tag(rhs.tag),
is_intrusive_ptr(rhs.is_intrusive_ptr) {
- if (is_intrusive_ptr) {
+ if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
}
}
WeakIValue(const IValue& rhs)
- : payload(rhs.payload),
- tag(rhs.tag),
+ : tag(rhs.tag),
is_intrusive_ptr(rhs.is_intrusive_ptr) {
+ if (rhs.isTensor()) {
+ payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
+ is_intrusive_ptr = true;
+ } else {
+ payload = rhs.payload.u;
+ }
if (is_intrusive_ptr) {
- c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+ if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+ c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+ }
}
}
WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() {
swap(rhs);
}
~WeakIValue() {
- if (is_intrusive_ptr) {
+ if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr);
}
}
@@ -895,17 +1000,33 @@ struct TORCH_API WeakIValue final {
IValue lock() const {
if (!is_intrusive_ptr) {
- return IValue(payload, tag, false);
+ IValue::Payload newPayload;
+ newPayload.u = payload;
+ return IValue(newPayload, tag, false);
}
- auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
- payload.as_intrusive_ptr);
- IValue::Payload pl;
- pl.as_intrusive_ptr = temp.lock().release();
- temp.release();
- if (!pl.as_intrusive_ptr) {
- return IValue();
+ if (IValue::Tag::Tensor == tag) {
+ auto temp = c10::weak_intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl>::reclaim(
+ static_cast<at::TensorImpl*>(payload.as_intrusive_ptr));
+ c10::intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl> ip(temp.lock());
+ temp.release();
+ if (!ip) {
+ return IValue();
+ } else {
+ return IValue(at::Tensor(std::move(ip)));
+ }
} else {
- return IValue(pl, tag, true);
+ auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
+ payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+ ? nullptr
+ : payload.as_intrusive_ptr);
+ IValue::Payload pl;
+ pl.u.as_intrusive_ptr = temp.lock().release();
+ temp.release();
+ if (!pl.u.as_intrusive_ptr) {
+ return IValue();
+ } else {
+ return IValue(pl, tag, true);
+ }
}
}
@@ -913,7 +1034,7 @@ struct TORCH_API WeakIValue final {
if (!is_intrusive_ptr) {
return 1;
}
- auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
+ auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(
payload.as_intrusive_ptr);
size_t result = temp.use_count();
temp.release();
@@ -924,7 +1045,7 @@ struct TORCH_API WeakIValue final {
if (!is_intrusive_ptr) {
return 1;
}
- auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
+ auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(
payload.as_intrusive_ptr);
size_t result = temp.weak_use_count();
temp.release();
@@ -935,7 +1056,8 @@ struct TORCH_API WeakIValue final {
}
private:
- IValue::Payload payload;
+ using Payload = IValue::Payload::TriviallyCopyablePayload;
+ Payload payload;
IValue::Tag tag;
bool is_intrusive_ptr;
};
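// [Editor's sketch, not part of the patch] A minimal illustration of the new
// IValue semantics introduced in this file:
//
//   IValue v(at::ones({2, 2}));         // Tensor lives inline in Payload::as_tensor
//   const at::Tensor& r = v.toTensor(); // new const& accessor: no refcount traffic
//   IValue w = std::move(v);            // moveFrom(): payload moved wholesale and
//                                       // v is reset to None via clearToNone()
//   // v.isNone() == true, w.isTensor() == true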
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index 89c8e669c138..b96f4b834989 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -48,14 +48,18 @@ struct tagged_capsule {
template <typename T, class NullType>
c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
auto t = c10::intrusive_ptr<T, NullType>::reclaim(
- static_cast<T*>(payload.as_intrusive_ptr));
+ payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+ ? NullType::singleton()
+ : static_cast<T*>(payload.u.as_intrusive_ptr));
clearToNone();
return t;
}
template <typename T, class NullType>
c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
auto r = c10::intrusive_ptr<T, NullType>::reclaim(
- static_cast<T*>(payload.as_intrusive_ptr));
+ payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+ ? NullType::singleton()
+ : static_cast<T*>(payload.u.as_intrusive_ptr));
auto p = r;
r.release();
return p;
@@ -131,12 +135,26 @@ inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
}
inline at::Tensor IValue::toTensor() && {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
- return at::Tensor(
- moveToIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
+ auto result = std::move(payload.as_tensor);
+ // As far as I can tell, omitting the usual explicit destructor call
+ // is not UB in and of itself, and it's a slight perf win. The
+ // destructor is a no-op, because the moved-from Tensor is
+ // effectively an intrusive_ptr in the null state, so we don't need
+ // the behavior for correctness reasons either. Leaving this
+ // explanatory comment, including commented-out destructor call, to
+ // make this abundantly clear.
+ //
+ // payload.as_tensor.~Tensor();
+ clearToNone();
+ return result;
}
-inline at::Tensor IValue::toTensor() const& {
+inline at::Tensor& IValue::toTensor() & {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
- return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
+ return payload.as_tensor;
+}
+inline const at::Tensor& IValue::toTensor() const& {
+ AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
+ return payload.as_tensor;
}
inline c10::Storage IValue::toStorage() && {
AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
@@ -148,10 +166,10 @@ inline c10::Storage IValue::toStorage() const& {
return c10::Storage(toIntrusivePtr<at::StorageImpl>());
}
inline c10::Stream IValue::toStream() && {
- return c10::Stream::unpack(payload.as_int);
+ return c10::Stream::unpack(payload.u.as_int);
}
inline c10::Stream IValue::toStream() const& {
- return c10::Stream::unpack(payload.as_int);
+ return c10::Stream::unpack(payload.u.as_int);
}
inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
@@ -713,7 +731,8 @@ using _guarded_unsigned_long = std::conditional_t<
inline const ivalue::Object& IValue::toObjectRef() const {
AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
- return *static_cast<const c10::ivalue::Object*>(payload.as_intrusive_ptr);
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
+ return *static_cast<const c10::ivalue::Object*>(payload.u.as_intrusive_ptr);
}
// note: when adding a DEFINE_TO case here you should also add a
@@ -729,6 +748,7 @@ inline const ivalue::Object& IValue::toObjectRef() const {
inline type IValue::to<type>() const& { \
return this->method_name(); \
}
+
DEFINE_TO(at::Tensor, toTensor)
DEFINE_TO(at::Storage, toStorage)
DEFINE_TO(c10::Stream, toStream)
@@ -980,8 +1000,11 @@ inline c10::List<int64_t> IValue::toIntList() const& {
}
inline std::vector<int64_t> IValue::toIntVector() const {
AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toIntVector on null intrusive_ptr IValue");
return createVectorFromList<int64_t>(
- static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr));
+ static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<double> IValue::toDoubleList() && {
AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
@@ -993,8 +1016,11 @@ inline c10::List<double> IValue::toDoubleList() const& {
}
inline std::vector<double> IValue::toDoubleVector() const {
AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toDoubleVector on null intrusive_ptr IValue");
return createVectorFromList<double>(
- static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr));
+ static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<bool> IValue::toBoolList() && {
AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
@@ -1014,8 +1040,11 @@ inline c10::List<at::Tensor> IValue::toTensorList() const& {
}
inline std::vector<at::Tensor> IValue::toTensorVector() const {
AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toTensorVector on null intrusive_ptr IValue");
return createVectorFromList<at::Tensor>(
- static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr));
+ static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<IValue> IValue::toList() && {
AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
@@ -1027,7 +1056,10 @@ inline c10::List<IValue> IValue::toList() const& {
}
inline c10::ArrayRef<IValue> IValue::toListRef() const {
AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
- return static_cast<const c10::detail::ListImpl*>(payload.as_intrusive_ptr)
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toListRef on null intrusive_ptr IValue");
+ return static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)
->list;
}
inline c10::Dict<IValue, IValue> IValue::toGenericDict() && {
@@ -1049,7 +1081,7 @@ inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() const& {
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
: tag(Tag::Tuple), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
template <
typename... Args,
@@ -1065,14 +1097,14 @@ inline IValue::IValue(const std::tuple<Args...>& t)
inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v)
: tag(Tag::String), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(std::string v)
: IValue(ivalue::ConstantString::create(std::move(v))) {}
inline IValue::IValue(c10::impl::GenericList v)
: tag(Tag::GenericList), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.impl_.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
}
template <class T, IValue::enable_if_ivalue_constructible<T>>
@@ -1104,7 +1136,7 @@ inline IValue::IValue(std::array<T, N> v) : IValue(c10::List<T>()) {
inline IValue::IValue(c10::impl::GenericDict v)
: tag(Tag::GenericDict), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.impl_.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
}
template <class Key, class Value>
inline IValue::IValue(c10::Dict<Key, Value> v)
@@ -1131,17 +1163,17 @@ inline IValue::IValue(c10::nullopt_t) : IValue() {}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
: tag(Tag::Object), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
: tag(Tag::PyObject), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(c10::intrusive_ptr