diff --git a/.circleci/config.yml b/.circleci/config.yml
index 0716e516518b..d19c08b2b0b6 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -11,6 +11,9 @@ parameters:
run_binary_tests:
type: boolean
default: false
+ run_build:
+ type: boolean
+ default: true
docker_config_defaults: &docker_config_defaults
user: jenkins
@@ -9762,6 +9765,7 @@ workflows:
only:
- postnightly
executor: windows-with-nvidia-gpu
+ when: << pipeline.parameters.run_build >>
ecr_gc:
triggers:
- schedule:
diff --git a/.circleci/generate_config_yml.py b/.circleci/generate_config_yml.py
index f1af924bd3e2..a836d2e510a6 100755
--- a/.circleci/generate_config_yml.py
+++ b/.circleci/generate_config_yml.py
@@ -112,7 +112,10 @@ def gen_build_workflows_tree():
"when": r"<< pipeline.parameters.run_binary_tests >>",
"jobs": [f() for f in binary_build_functions],
},
- "build": {"jobs": [f() for f in build_workflows_functions]},
+ "build": {
+ "when": r"<< pipeline.parameters.run_build >>",
+ "jobs": [f() for f in build_workflows_functions]
+ },
}
}
diff --git a/.circleci/verbatim-sources/header-section.yml b/.circleci/verbatim-sources/header-section.yml
index 26205a0cccba..43d4c94ee5ed 100644
--- a/.circleci/verbatim-sources/header-section.yml
+++ b/.circleci/verbatim-sources/header-section.yml
@@ -11,6 +11,9 @@ parameters:
run_binary_tests:
type: boolean
default: false
+ run_build:
+ type: boolean
+ default: true
docker_config_defaults: &docker_config_defaults
user: jenkins
diff --git a/.github/pytorch-circleci-labels.yml b/.github/pytorch-circleci-labels.yml
index ccdf2e876af1..3a9eeca0abcc 100644
--- a/.github/pytorch-circleci-labels.yml
+++ b/.github/pytorch-circleci-labels.yml
@@ -9,3 +9,5 @@ labels_to_circle_params:
- release/.*
tags:
- v[0-9]+(\.[0-9]+)*-rc[0-9]+
+ set_to_false:
+ - run_build
diff --git a/.jenkins/pytorch/README.md b/.jenkins/pytorch/README.md
index ea6c6dd40f68..9fd68ecf7f15 100644
--- a/.jenkins/pytorch/README.md
+++ b/.jenkins/pytorch/README.md
@@ -10,9 +10,9 @@ it is very easy to run these tests yourself:
``registry.pytorch.org/pytorch/pytorch-$BUILD_ENVIRONMENT:$DOCKER_VERSION``,
where ``$BUILD_ENVIRONMENT`` is one of the build environments
enumerated in
- [pytorch-dockerfiles](https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh)
+   [pytorch-dockerfiles](https://github.com/pytorch/pytorch/blob/master/.circleci/docker/build.sh). The dockerfile used by jenkins can be found under the `.circleci` [directory](https://github.com/pytorch/pytorch/blob/master/.circleci/docker)
-2. Run ``docker -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and
+2. Run ``docker run -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and
run one of the scripts in this directory.
The Docker images are designed so that any "reasonable" build commands
@@ -38,5 +38,5 @@ mechanisms we use:
build scripts.
- We reroute well known paths like `/usr/bin/gcc` to alternate
- implementations with `update-alternatives, instead of setting
+ implementations with `update-alternatives`, instead of setting
`CC` and `CXX` in our implementations.
diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh
index 17e7e9fa3445..47d13f2908d0 100755
--- a/.jenkins/pytorch/codegen-test.sh
+++ b/.jenkins/pytorch/codegen-test.sh
@@ -48,13 +48,6 @@ python -m tools.autograd.gen_autograd \
"$OUT"/autograd \
tools/autograd
-# unboxing_wrappers codegen (called by torch codegen but can run independently)
-mkdir -p "$OUT"/unboxing_wrappers
-python -m tools.jit.gen_unboxing_wrappers \
- "$OUT"/torch/share/ATen/Declarations.yaml \
- "$OUT"/unboxing_wrappers \
- tools/jit/templates
-
# annotated_fn_args codegen (called by torch codegen but can run independently)
mkdir -p "$OUT"/annotated_fn_args
python -m tools.autograd.gen_annotated_fn_args \
diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh
index 0c34ddcc6179..24ec02c76df5 100755
--- a/.jenkins/pytorch/macos-test.sh
+++ b/.jenkins/pytorch/macos-test.sh
@@ -9,11 +9,6 @@ pip install -q hypothesis "librosa>=0.6.2" "numba<=0.49.1" psutil
# TODO move this to docker
pip install unittest-xml-reporting pytest
-# faulthandler become built-in since 3.3
-if [[ ! $(python -c "import sys; print(int(sys.version_info >= (3, 3)))") == "1" ]]; then
- pip install -q faulthandler
-fi
-
if [ -z "${IN_CI}" ]; then
rm -rf ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch*
fi
diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
index a052a1b67d59..ed6482890993 100644
--- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
+++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat
@@ -41,8 +41,6 @@ popd
:: The version is fixed to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest coverage
if %errorlevel% neq 0 ( exit /b %errorlevel% )
-:: No need to install faulthandler since we only test Python >= 3.6 on Windows
-:: faulthandler is builtin since Python 3.3
set DISTUTILS_USE_SDK=1
diff --git a/BUILD.bazel b/BUILD.bazel
index b3faea487965..2b4636d850c9 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -193,9 +193,6 @@ libtorch_cpp_generated_sources = [
"torch/csrc/autograd/generated/Functions.h",
"torch/csrc/autograd/generated/Functions.cpp",
"torch/csrc/autograd/generated/variable_factories.h",
- "torch/csrc/jit/generated/generated_unboxing_wrappers_0.cpp",
- "torch/csrc/jit/generated/generated_unboxing_wrappers_1.cpp",
- "torch/csrc/jit/generated/generated_unboxing_wrappers_2.cpp",
]
libtorch_python_generated_sources = [
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ba862b5a4d5f..3df73f8a3041 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -173,6 +173,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF)
cmake_dependent_option(
USE_NCCL "Use NCCL" ON
"USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
+cmake_dependent_option(USE_RCCL "Use RCCL" ON
+    "USE_NCCL" OFF)
cmake_dependent_option(
USE_STATIC_NCCL "Use static NCCL" OFF
"USE_NCCL" OFF)
@@ -316,7 +318,7 @@ set(OP_DEPENDENCY "" CACHE STRING
# symbol lookup error: miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: mkl_blas_dsyrk
# https://software.intel.com/en-us/articles/symbol-lookup-error-when-linking-intel-mkl-with-gcc-on-ubuntu
if(LINUX)
- set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed")
+ set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed ${CMAKE_SHARED_LINKER_FLAGS}")
endif()
if(MSVC)
diff --git a/android/test_app/app/src/main/AndroidManifest.xml b/android/test_app/app/src/main/AndroidManifest.xml
index a83bf223bdaf..abdd9a8d986a 100644
--- a/android/test_app/app/src/main/AndroidManifest.xml
+++ b/android/test_app/app/src/main/AndroidManifest.xml
@@ -18,4 +18,10 @@
+
+
+
+
diff --git a/aten/conda/meta.yaml b/aten/conda/meta.yaml
index d8096fc73a0f..a502690a5447 100644
--- a/aten/conda/meta.yaml
+++ b/aten/conda/meta.yaml
@@ -24,7 +24,7 @@ requirements:
- mkl # [not osx]
about:
- home: https://github.com/zdevito/ATen
+ home: https://github.com/pytorch/pytorch
license: BSD
summary: A TENsor library for C++14
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index 9bdec2dce77e..2cd7cac4e71b 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -1015,7 +1015,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("_add_batch_dim", native::_add_batch_dim);
m.impl("_remove_batch_dim", native::_remove_batch_dim);
- m.impl_UNBOXED("sum.dim_IntList", sum_batching_rule);
+ m.impl("sum.dim_IntList", sum_batching_rule);
m.impl("is_complex", native::is_complex);
m.impl("conj", native::conj);
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index fd3c95f2573b..6fedef185b21 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -72,7 +72,7 @@ file(GLOB metal_h "metal/*.h")
file(GLOB metal_cpp "metal/*.cpp")
file(GLOB_RECURSE native_metal_h "native/metal/*.h")
file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm")
-file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm", "native/metal/*.cpp")
+file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm" "native/metal/*.cpp")
EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs})
file(GLOB metal_prepack_h "native/metal/MetalPrepackOpContext.h")
file(GLOB metal_prepack_cpp "native/metal/MetalPrepackOpRegister.cpp")
diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp
index bfa4a2a8f72f..ff4a2f1c61e2 100644
--- a/aten/src/ATen/CPUGeneratorImpl.cpp
+++ b/aten/src/ATen/CPUGeneratorImpl.cpp
@@ -1,4 +1,6 @@
#include
+#include
+#include
#include
#include
@@ -6,6 +8,42 @@ namespace at {
namespace detail {
+/**
+ * CPUGeneratorImplStateLegacy is a POD class needed for memcpys
+ * in torch.get_rng_state() and torch.set_rng_state().
+ * It is a legacy class and even though it is replaced with
+ * at::CPUGeneratorImpl, we need this class and some of its fields
+ * to support backward compatibility on loading checkpoints.
+ */
+struct CPUGeneratorImplStateLegacy {
+ /* The initial seed. */
+ uint64_t the_initial_seed;
+ int left; /* = 1; */
+ int seeded; /* = 0; */
+ uint64_t next;
+ uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */
+
+ /********************************/
+
+ /* For normal distribution */
+ double normal_x;
+ double normal_y;
+ double normal_rho;
+ int normal_is_valid; /* = 0; */
+};
+
+/**
+ * CPUGeneratorImplState is a POD class containing
+ * new data introduced in at::CPUGeneratorImpl and the legacy state. It is used
+ * as a helper for torch.get_rng_state() and torch.set_rng_state()
+ * functions.
+ */
+struct CPUGeneratorImplState {
+ CPUGeneratorImplStateLegacy legacy_pod;
+ float next_float_normal_sample;
+ bool is_next_float_normal_sample_valid;
+};
+
/**
* PyTorch maintains a collection of default generators that get
* initialized once. The purpose of these default generators is to
@@ -75,6 +113,128 @@ uint64_t CPUGeneratorImpl::seed() {
return random;
}
+/**
+ * Sets the internal state of CPUGeneratorImpl. The new internal state
+ * must be a strided CPU byte tensor and of the same size as either
+ * CPUGeneratorImplStateLegacy (for legacy CPU generator state) or
+ * CPUGeneratorImplState (for new state).
+ *
+ * FIXME: Remove support of the legacy state in the future?
+ */
+void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
+ using detail::CPUGeneratorImplState;
+ using detail::CPUGeneratorImplStateLegacy;
+
+ static_assert(std::is_pod::value, "CPUGeneratorImplStateLegacy is not a PODType");
+ static_assert(std::is_pod::value, "CPUGeneratorImplState is not a PODType");
+
+ static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy);
+ static const size_t size_current = sizeof(CPUGeneratorImplState);
+ static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size");
+
+ detail::check_rng_state(new_state);
+
+ at::mt19937 engine;
+ auto float_normal_sample = c10::optional();
+ auto double_normal_sample = c10::optional();
+
+ // Construct the state of at::CPUGeneratorImpl based on input byte tensor size.
+ CPUGeneratorImplStateLegacy* legacy_pod;
+ auto new_state_size = new_state.numel();
+ if (new_state_size == size_legacy) {
+ legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data();
+ // Note that in CPUGeneratorImplStateLegacy, we didn't have float version
+ // of normal sample and hence we leave the c10::optional as is
+
+ // Update next_double_normal_sample.
+ // Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y)
+ // and a rho value (normal_rho). These three values were redundant and in the new
+ // DistributionsHelper.h, we store the actual extra normal sample, rather than three
+ // intermediate values.
+ if (legacy_pod->normal_is_valid) {
+ auto r = legacy_pod->normal_rho;
+ auto theta = 2.0 * M_PI * legacy_pod->normal_x;
+ // we return the sin version of the normal sample when in caching mode
+ double_normal_sample = c10::optional(r * ::sin(theta));
+ }
+ } else if (new_state_size == size_current) {
+ auto rng_state = (CPUGeneratorImplState*)new_state.data();
+ legacy_pod = &rng_state->legacy_pod;
+ // update next_float_normal_sample
+ if (rng_state->is_next_float_normal_sample_valid) {
+ float_normal_sample = c10::optional(rng_state->next_float_normal_sample);
+ }
+
+ // Update next_double_normal_sample.
+ // Note that in getRNGState, we now return the actual normal sample in normal_y
+ // and if it's valid in normal_is_valid. The redundant normal_x and normal_rho
+ // are squashed to 0.0.
+ if (legacy_pod->normal_is_valid) {
+ double_normal_sample = c10::optional(legacy_pod->normal_y);
+ }
+ } else {
+ AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy,
+ " or a CPUGeneratorImplState of size ", size_current,
+ " but found the input RNG state size to be ", new_state_size);
+ }
+
+ // construct engine_
+ // Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our
+ // redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are
+ // doing a std::copy.
+ at::mt19937_data_pod rng_data;
+ std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin());
+ rng_data.seed_ = legacy_pod->the_initial_seed;
+ rng_data.left_ = legacy_pod->left;
+ rng_data.seeded_ = legacy_pod->seeded;
+ rng_data.next_ = static_cast(legacy_pod->next);
+ engine.set_data(rng_data);
+ TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state");
+ this->engine_ = engine;
+ this->next_float_normal_sample_ = float_normal_sample;
+ this->next_double_normal_sample_ = double_normal_sample;
+}
+
+/**
+ * Gets the current internal state of CPUGeneratorImpl. The internal
+ * state is returned as a CPU byte tensor.
+ */
+c10::intrusive_ptr CPUGeneratorImpl::get_state() const {
+ using detail::CPUGeneratorImplState;
+
+ static const size_t size = sizeof(CPUGeneratorImplState);
+ static_assert(std::is_pod::value, "CPUGeneratorImplState is not a PODType");
+
+ auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+ auto rng_state = state_tensor.data_ptr();
+
+ // accumulate generator data to be copied into byte tensor
+ auto accum_state = std::make_unique();
+ auto rng_data = this->engine_.data();
+ accum_state->legacy_pod.the_initial_seed = rng_data.seed_;
+ accum_state->legacy_pod.left = rng_data.left_;
+ accum_state->legacy_pod.seeded = rng_data.seeded_;
+ accum_state->legacy_pod.next = rng_data.next_;
+ std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state));
+ accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy
+ accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy
+ accum_state->legacy_pod.normal_is_valid = false;
+ accum_state->legacy_pod.normal_y = 0.0;
+ accum_state->next_float_normal_sample = 0.0f;
+ accum_state->is_next_float_normal_sample_valid = false;
+ if (this->next_double_normal_sample_) {
+ accum_state->legacy_pod.normal_is_valid = true;
+ accum_state->legacy_pod.normal_y = *(this->next_double_normal_sample_);
+ }
+ if (this->next_float_normal_sample_) {
+ accum_state->is_next_float_normal_sample_valid = true;
+ accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
+ }
+
+ memcpy(rng_state, accum_state.get(), size);
+ return state_tensor.getIntrusivePtr();
+}
+
/**
* Gets the DeviceType of CPUGeneratorImpl.
* Used for type checking during run time.
diff --git a/aten/src/ATen/CPUGeneratorImpl.h b/aten/src/ATen/CPUGeneratorImpl.h
index eceb338966fd..f8b43a04c73c 100644
--- a/aten/src/ATen/CPUGeneratorImpl.h
+++ b/aten/src/ATen/CPUGeneratorImpl.h
@@ -17,6 +17,8 @@ struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
+ void set_state(const c10::TensorImpl& new_state) override;
+ c10::intrusive_ptr get_state() const override;
static DeviceType device_type();
uint32_t random();
uint64_t random64();
diff --git a/aten/src/ATen/CUDAGeneratorImpl.h b/aten/src/ATen/CUDAGeneratorImpl.h
index 9a9febd01f8e..1179a049aa08 100644
--- a/aten/src/ATen/CUDAGeneratorImpl.h
+++ b/aten/src/ATen/CUDAGeneratorImpl.h
@@ -129,8 +129,10 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
+ void set_state(const c10::TensorImpl& new_state) override;
+ c10::intrusive_ptr get_state() const override;
void set_philox_offset_per_thread(uint64_t offset);
- uint64_t philox_offset_per_thread();
+ uint64_t philox_offset_per_thread() const;
void capture_prologue(int64_t* offset_extragraph);
uint64_t capture_epilogue();
PhiloxCudaState philox_cuda_state(uint64_t increment);
diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 41252609953f..341e20cab1f3 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -10,6 +10,9 @@
#include
#include
+#ifdef XPLAT_MOBILE_BUILD
+#include
+#else
namespace at {
/**
* The method should_include_kernel_dtype() returns true/false
@@ -25,6 +28,7 @@ inline constexpr bool should_include_kernel_dtype(
return true;
}
}
+#endif
/**
* In the Facebook internal build (using BUCK), this macro is enabled by
@@ -93,26 +97,6 @@ inline constexpr bool should_include_kernel_dtype(
return __VA_ARGS__(); \
}
-// This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and
-// should be removed once the bfloat16 bringup is complete on other platforms.
-// This is supposed to be used as a wrapper around the lambda function passed to
-// the dispatch macro and will conditionally dispatch ops with bfloat16 type
-// only on ROCm.
-#if !defined(__HIP_PLATFORM_HCC__)
-#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) \
- if (std::is_same::value) { \
- AT_ERROR( \
- #NAME, \
- " not implemented for '", \
- toString(at::ScalarType::BFloat16), \
- "'"); \
- } else { \
- return __VA_ARGS__(); \
- }
-#else
-#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) return __VA_ARGS__()
-#endif
-
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h
index 5063beeb08b0..8fa085245459 100644
--- a/aten/src/ATen/VmapTransforms.h
+++ b/aten/src/ATen/VmapTransforms.h
@@ -96,8 +96,17 @@ struct VmapPhysicalToLogicalMap;
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
-// bit of `levels` specifies the minimum number of nested vmaps we are in at
+// bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
+// For example, given:
+// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
+//
+// The rightmost set bit of `levels` is 3, indicating that the number of
+// nested vmaps is less than or equal to 3.
+// bitset: 010100
+// ^
+// |
+// levels: 012345
struct TORCH_API VmapPhysicalView {
VmapPhysicalView(Tensor&& tensor, std::bitset levels)
: levels_(levels), tensor_(tensor) {
diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index dfb8e3ac0f32..9a2f34257c57 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -239,13 +239,9 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction::type::call);
-#define KERNEL_UNBOXED_ONLY(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
- m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
- &WrapFunction::type::call);
-
// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
-#define KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
- m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
+ m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction::type::call);
/*****************************************
@@ -367,20 +363,20 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional&, const c10::optional&, int64_t), fp32)
KERNEL(ADD_NS(dist), "dist", Tensor (const Tensor &, const Tensor &, Scalar), fp32)
KERNEL(ADD_NS(pdist), "pdist", Tensor (const Tensor &, double), fp32)
- KERNEL_UNBOXED_ONLY(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional), fp32)
+ KERNEL(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional), fp32)
KERNEL(ADD_NS(renorm), "renorm", Tensor (const Tensor &, Scalar, int64_t, Scalar), fp32)
// fp32_set_opt_dtype
KERNEL(ADD_NS(prod), "prod", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(prod), "prod.dim_int", Tensor (const Tensor &, int64_t, bool, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(softmax), "softmax.int", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(log_softmax), "log_softmax.int", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumprod), "cumprod", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumsum), "cumsum", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype)
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
// when autocasting.
// KERNEL(ADD_NS(norm), "norm.ScalarOpt_dtype", Tensor (const Tensor &, c10::optional, ScalarType), fp32_set_opt_dtype)
@@ -388,20 +384,20 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
// KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional), fp32_set_opt_dtype)
- KERNEL_UNBOXED_ONLY(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype)
+ KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype)
// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
// norm does not implicitly promote, but be aware when adding new ops to this policy.
- KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional, ScalarType), fp32_append_dtype)
- KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional, IntArrayRef, bool), Tensor (const Tensor &, c10::optional, IntArrayRef, bool, ScalarType), fp32_append_dtype)
- KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional, DimnameList, bool), Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_append_dtype)
+ KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional, ScalarType), fp32_append_dtype)
+ KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional, IntArrayRef, bool), Tensor (const Tensor &, c10::optional, IntArrayRef, bool, ScalarType), fp32_append_dtype)
+ KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional, DimnameList, bool), Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_append_dtype)
// promote
KERNEL(ADD_NS(addcdiv), "addcdiv", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote)
KERNEL(ADD_NS(addcmul), "addcmul", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote)
KERNEL(ADD_NS(atan2), "atan2", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(bilinear), "bilinear", Tensor (const Tensor &, const Tensor &, const Tensor &, const c10::optional&), promote)
KERNEL(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
- KERNEL_UNBOXED_ONLY(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote)
+ KERNEL(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote)
KERNEL(ADD_NS(_cat), "_cat", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional), promote)
KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
diff --git a/aten/src/ATen/core/Generator.cpp b/aten/src/ATen/core/Generator.cpp
new file mode 100644
index 000000000000..800f8c7c88ec
--- /dev/null
+++ b/aten/src/ATen/core/Generator.cpp
@@ -0,0 +1,16 @@
+#include
+#include
+#include
+
+namespace at {
+
+void Generator::set_state(const at::Tensor& new_state) {
+ TORCH_CHECK(new_state.defined(), "Undefined tensor is not allowed");
+ this->impl_->set_state(*new_state.unsafeGetTensorImpl());
+}
+
+at::Tensor Generator::get_state() const {
+ return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
+}
+
+} // namespace at
diff --git a/aten/src/ATen/core/Generator.h b/aten/src/ATen/core/Generator.h
index de3f6e46f8f2..b5bbb2fe3c74 100644
--- a/aten/src/ATen/core/Generator.h
+++ b/aten/src/ATen/core/Generator.h
@@ -56,6 +56,8 @@
namespace at {
+class Tensor;
+
struct TORCH_API Generator {
Generator() {}
@@ -96,6 +98,12 @@ struct TORCH_API Generator {
uint64_t seed() { return impl_->seed(); }
+ // Implementation not inlined to prevent cycle reference between
+ // `ATen/core/Generator.h` and `ATen/core/Tensor.h`
+ void set_state(const at::Tensor& new_state);
+
+ at::Tensor get_state() const;
+
std::mutex& mutex() {
return impl_->mutex_;
}
@@ -130,4 +138,24 @@ Generator make_generator(Args&&... args) {
return Generator(c10::make_intrusive(std::forward(args)...));
}
+namespace detail {
+
+/**
+ * Helper function for checking the validity of new random generator
+ * state. Right now following conditions are checked:
+ *
+ * - The new state tensor must be a torch.ByteTensor
+ * - Data of the new state tensor must be contiguous
+ */
+static inline void check_rng_state(const c10::TensorImpl& new_state) {
+ TORCH_CHECK_TYPE(
+ new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
+ "RNG state must be a torch.ByteTensor"
+ );
+
+ TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
+}
+
+} // namespace detail
+
} // namespace at
diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp
index f84352ebee1f..58c35557018c 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.cpp
+++ b/aten/src/ATen/core/boxing/KernelFunction.cpp
@@ -57,25 +57,4 @@ bool KernelFunction::_equalsBoxedAndUnboxed(const KernelFunction& other) const {
unboxed_kernel_func_ == other.unboxed_kernel_func_;
}
-void KernelFunction::checkBoxedKernel(const OperatorHandle& opHandle) const {
- if (C10_UNLIKELY(boxed_kernel_func_ == nullptr)) {
- if (unboxed_kernel_func_ == nullptr) {
- TORCH_INTERNAL_ASSERT(
- false,
- "Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction.",
- " opname: ",
- opHandle.operator_name(),
- " If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`.");
- } else {
- // TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this case should be impossible.
- TORCH_INTERNAL_ASSERT(
- false,
- "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call().",
- " opname: ",
- opHandle.operator_name(),
- " If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`.");
- }
- }
-}
-
} // namespace c10
diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index 6817907b12b1..ddbbd912777a 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -123,26 +123,6 @@ class TORCH_API KernelFunction final {
template
static KernelFunction makeFromUnboxedFunctor(std::unique_ptr kernelFunctor);
- /**
- * Create a KernelFunction from an unboxed functor and prevent creation of an
- * unboxing-wrapper. This means that you cannot call this KernelFunction
- * using KernelFunction::callBoxed()
- *
- * This is necessary because our unboxing wrappers don't work for all types
- * yet, so if you want to use one of these types as function arguments,
- * you need to use makeFromUnboxedOnlyFunctor.
- *
- * Example:
- *
- * > class MyFunctor final {
- * > public:
- * > Tensor operator()(Tensor a, Tensor b) {...}
- * > };
- * > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::make_unique());
- */
- template
- static KernelFunction makeFromUnboxedOnlyFunctor(std::unique_ptr kernelFunctor);
-
/**
* Create a KernelFunction from an unboxed function.
* This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
@@ -158,23 +138,6 @@ class TORCH_API KernelFunction final {
template
static KernelFunction makeFromUnboxedFunction(FuncPtr);
- /**
- * Create a KernelFunction from an unboxed function and prevent creation of an
- * unboxing-wrapper. This means that you cannot call this KernelFunction
- * using KernelFunction::callBoxed()
- *
- * This is necessary because our unboxing wrappers don't work for all types
- * yet, so if you want to use one of these types as function arguments,
- * you need to use makeFromUnboxedOnlyFunctor.
- *
- * Example:
- *
- * > Tensor unboxed_func(Tensor a, Tensor b) {...}
- * > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction();
- */
- template
- static KernelFunction makeFromUnboxedOnlyFunction(FuncPtr);
-
/**
* Create a KernelFunction from an unboxed function.
* KernelFunction::makeFromUnboxedFunction is usually a better choice than
@@ -189,9 +152,6 @@ class TORCH_API KernelFunction final {
template
static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
- template
- static KernelFunction makeFromUnboxedOnlyRuntimeFunction(FuncType* func);
-
static KernelFunction makeFallthrough();
static KernelFunction makeAmbiguousAutogradOther();
static KernelFunction makeNamedNotSupported();
@@ -213,12 +173,6 @@ class TORCH_API KernelFunction final {
// For testing internal invariants only
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
- // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic. This can be done once https://github.com/pytorch/pytorch/issues/32366 is fixed.
- void setManuallyBoxedKernel_(InternalBoxedKernelFunction* func);
-
private:
explicit KernelFunction(std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func);
@@ -226,8 +180,6 @@ class TORCH_API KernelFunction final {
template
static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, Stack* stack);
- void checkBoxedKernel(const OperatorHandle& opHandle) const;
-
OperatorKernel* getFunctor_() const;
std::shared_ptr functor_;
diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h
index 82a65fa27ffb..b248e54a6f94 100644
--- a/aten/src/ATen/core/boxing/KernelFunction_impl.h
+++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h
@@ -23,8 +23,7 @@ inline void KernelFunction::make_boxed_function(OperatorKernel*, const OperatorH
}
inline bool KernelFunction::isValid() const {
- // TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this should only check boxed_kernel_func_.
- return boxed_kernel_func_ != nullptr || unboxed_kernel_func_ != nullptr;
+ return boxed_kernel_func_ != nullptr;
}
inline bool KernelFunction::isFallthrough() const {
@@ -32,7 +31,10 @@ inline bool KernelFunction::isFallthrough() const {
}
inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, Stack* stack) const {
- checkBoxedKernel(opHandle);
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ boxed_kernel_func_ != nullptr,
+ "Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction."
+ );
(*boxed_kernel_func_)(functor_.get(), opHandle, stack);
}
@@ -111,21 +113,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr
-inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr kernelFunctor) {
- // TODO We want to get rid of kernels that have only an unboxed function pointer.
- // All kernels should have a boxed pointer.
-
- static_assert(guts::is_functor::value, "Tried to call KernelFunction::makeFromUnboxedFunctor but the argument is not a functor.");
- static_assert(std::is_base_of::value, "Tried to call KernelFunction::makeFromUnboxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
-
- return KernelFunction(
- std::move(kernelFunctor),
- nullptr, // Don't create a boxed kernel for this
- reinterpret_cast(&impl::wrap_kernel_functor_unboxed::call)
- );
-}
-
template
inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) {
static_assert(is_compile_time_function_pointer::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
@@ -144,26 +131,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr)
#endif
}
-template
-inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunction(FuncPtr func_ptr) {
- // TODO We want to get rid of kernels that have only an unboxed function pointer.
- // All kernels should have a boxed pointer.
- static_assert(is_compile_time_function_pointer::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
- static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
- static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
-
-#if !defined(C10_MOBILE)
- return makeFromUnboxedOnlyFunctor::type> (
- guts::make_unique_base::type>()
- );
-#else
- // On mobile, we rather want to optimize for binary size than for performance,
- // so let's not inline the kernel into the wrapper but use makeFromUnboxedOnlyRuntimeFunction
- // instead.
- return makeFromUnboxedOnlyRuntimeFunction(func_ptr.func_ptr());
-#endif
-}
-
template
inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) {
static_assert(guts::is_function_type::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
@@ -175,17 +142,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* f
);
}
-template
-inline KernelFunction KernelFunction::makeFromUnboxedOnlyRuntimeFunction(FuncType* func) {
- static_assert(guts::is_function_type::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
- static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
- TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
-
- return makeFromUnboxedOnlyFunctor>>(
- guts::make_unique_base>>(func)
- );
-}
-
template
inline std::enable_if_t>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
static_assert(guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
@@ -212,14 +168,4 @@ inline std::enable_if_t>::value,
);
}
-inline void KernelFunction::setManuallyBoxedKernel_(InternalBoxedKernelFunction* func) {
- if (boxed_kernel_func_ == &fallthrough_kernel) {
- // special case no-op
- return;
- }
- TORCH_INTERNAL_ASSERT(boxed_kernel_func_ == nullptr, "Tried to set a manually boxed kernel for a kernel that already has a boxed kernel set.");
- TORCH_INTERNAL_ASSERT(unboxed_kernel_func_ != nullptr, "Tried to set a manually boxed kernel for an invalid KernelFunction.");
- boxed_kernel_func_ = func;
-}
-
}
diff --git a/aten/src/ATen/core/boxing/KernelFunction_test.cpp b/aten/src/ATen/core/boxing/KernelFunction_test.cpp
index 8ba50db14a2b..e17efab10ba5 100644
--- a/aten/src/ATen/core/boxing/KernelFunction_test.cpp
+++ b/aten/src/ATen/core/boxing/KernelFunction_test.cpp
@@ -544,26 +544,6 @@ TEST(KernelFunctionTest, givenUnboxedFunctor_withoutReturn_whenCallingUnboxed_th
kernels::expectUnboxedCallingWithoutReturnWorks(func);
}
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withoutReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectUnboxedCallingWithReturnWorks(func);
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withoutReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique()));
- kernels::expectUnboxedCallingWithoutReturnWorks(func);
-}
-
TEST(KernelFunctionTest, givenUnboxedFunction_withReturn_whenCallingBoxed_thenWorks) {
KernelFunction func = KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernels::unboxed_function_with_return));
kernels::expectBoxedCallingWithReturnWorks(func);
@@ -584,26 +564,6 @@ TEST(KernelFunctionTest, givenUnboxedFunction_withoutReturn_whenCallingUnboxed_t
kernels::expectUnboxedCallingWithoutReturnWorks(func);
}
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_with_return));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withoutReturn_whenCallingBoxed_thenFails) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_without_return));
- kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()");
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_with_return));
- kernels::expectUnboxedCallingWithReturnWorks(func);
-}
-
-TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withoutReturn_whenCallingUnboxed_thenWorks) {
- KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_without_return));
- kernels::expectUnboxedCallingWithoutReturnWorks(func);
-}
-
TEST(KernelFunctionTest, givenUnboxedRuntimeFunction_withReturn_whenCallingBoxed_thenWorks) {
KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&kernels::unboxed_function_with_return);
kernels::expectBoxedCallingWithReturnWorks(func);
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp
index 5e3e91afbb45..270cffaf6d1f 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.cpp
+++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -295,12 +295,6 @@ void Dispatcher::checkInvariants() const {
}
}
-void Dispatcher::setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func) {
- std::lock_guard lock(mutex_);
- op.operatorIterator_->op.setManuallyBoxedKernel_(*this, func);
- // NB: Do not need to set manually boxed kernel for backend fallbacks
-}
-
std::vector Dispatcher::findDanglingImpls() const {
return operatorLookupTable_.read([&] (const ska::flat_hash_map& operatorLookupTable) -> std::vector {
std::vector opsWithDanglingImpls;
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 60f9f9bd0579..d83653f75363 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -182,12 +182,6 @@ class TORCH_API Dispatcher final {
*/
RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
- // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete setBoxedKernelFor_ once all operators work with the templated boxing logic
- void setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func);
-
// ------------------------------------------------------------------------
//
// Listeners on registrations
@@ -310,7 +304,9 @@ class TORCH_API OperatorHandle {
// smuggle in a kernel that is typed incorrectly). For everything
// in core library this won't happen, because all the static registrations
// will be done by the time a typed() handle is acquired.
+#if !defined C10_MOBILE
operatorIterator_->op.assertSignatureIsCorrect();
+#endif
return TypedOperatorHandle(operatorIterator_);
}
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index f0d7bc6968ed..7c3698beeb06 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -21,7 +21,6 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
, schema_()
, dispatchTable_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
-, manuallyBoxedKernel_()
, kernels_()
, cpp_signature_()
, is_observed_(ObservedOperators::isObserved(name_))
@@ -122,10 +121,6 @@ std::list::iterator OperatorEntry::registerKernel(
);
}
- if (manuallyBoxedKernel_.has_value()) {
- kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_);
- }
-
k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
std::list::iterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
@@ -331,19 +326,6 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher)
}
}
-void OperatorEntry::setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func) {
- TORCH_INTERNAL_ASSERT(!manuallyBoxedKernel_);
- manuallyBoxedKernel_ = func;
-
- for (auto& kv : kernels_) {
- for (auto& k : kv.second) {
- k.kernel.setManuallyBoxedKernel_(func);
- }
- }
- // Refresh entries in dispatchTable_
- updateDispatchTableFull_(dispatcher);
-}
-
void OperatorEntry::checkInvariants() const {
if (schema_) {
TORCH_INTERNAL_ASSERT(schema_->schema.operator_name() == name_, dumpState());
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h
index 5098fd0d8c28..44b8fac5661e 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.h
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -148,12 +148,6 @@ class TORCH_API OperatorEntry final {
const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; }
- // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic
- void setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func);
-
// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
template
void assertSignatureIsCorrect() {
@@ -189,12 +183,6 @@ class TORCH_API OperatorEntry final {
std::array(DispatchKey::NumDispatchKeys)> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;
- // This manuallyBoxedKernel_ member is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
- // unboxing wrapper for aten operators. We still need those for some operators because not all work
- // with the templated unboxing logic yet.
- // TODO Delete manuallyBoxedKernel_ once all operators work with the templated boxing logic
- c10::optional manuallyBoxedKernel_;
-
// kernels_ stores all registered kernels for the corresponding dispatch key
// and catchAllKernels_ stores the catch-all kernels.
// If an operator library gets loaded that overwrites an already existing kernel,
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 320fa6294638..1223577c59c6 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -265,7 +265,7 @@ bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) {
TORCH_INTERNAL_ASSERT(lhs.is_intrusive_ptr);
TORCH_INTERNAL_ASSERT(rhs.is_intrusive_ptr);
return lhs.tag == rhs.tag &&
- lhs.payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
+ lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
}
IValue IValue::equals(const IValue& rhs) const {
@@ -325,17 +325,17 @@ size_t IValue::hash(const IValue& v) {
case Tag::None:
return 0;
case Tag::Bool:
- return c10::get_hash(v.payload.as_bool);
+ return c10::get_hash(v.payload.u.as_bool);
case Tag::Double:
- return c10::get_hash(v.payload.as_double);
+ return c10::get_hash(v.payload.u.as_double);
case Tag::Tensor:
// Tensor __hash__ is equivalent to `id()`, so take the pointer value of
// the tensor to emulate it
- return c10::get_hash(v.payload.as_int);
+ return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
case Tag::Storage:
- return c10::get_hash(v.payload.as_int);
+ return c10::get_hash(v.payload.u.as_int);
case Tag::Int:
- return c10::get_hash(v.payload.as_int);
+ return c10::get_hash(v.payload.u.as_int);
case Tag::String:
return c10::get_hash(v.toStringRef());
case Tag::Tuple:
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 4a7e15c4008b..ca68a8df46e1 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -131,10 +131,15 @@ struct Capsule {
// they are marked `@private`, which hides them on the doxygen documentation for
// this page.
-/// IValue (Interpreter Value) is a tagged union over the types supported by the
-/// TorchScript interpreter. IValues contain their values as an
-/// `IValue::Payload`, which holds primitive types (`int64_t`, `bool`, `double`,
-/// `Device`), as values and all other types as a `c10::intrusive_ptr`.
+/// IValue (Interpreter Value) is a tagged union over the types
+/// supported by the TorchScript interpreter. IValues contain their
+/// values as an `IValue::Payload`, which holds primitive types
+/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values,
+/// and all other types as a `c10::intrusive_ptr`. In order to
+/// optimize performance of the destructor and related operations by
+/// making the `Tensor` and `c10::intrusive_ptr` paths generate the
+/// same code, we represent a null `c10::intrusive_ptr` as
+/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`.
///
/// IValues are used as inputs to and outputs from the TorchScript interpreter.
/// To retrieve the value contained within an IValue, use the `.toX()` methods,
@@ -160,27 +165,35 @@ struct Capsule {
struct TORCH_API IValue final {
IValue(const IValue& rhs)
: IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) {
- if (is_intrusive_ptr) {
- c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr);
+ if (is_intrusive_ptr && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+ c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
}
}
- IValue(IValue&& rhs) noexcept : IValue() {
- swap(rhs);
+
+ IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) {
+ moveFrom(std::move(rhs));
}
+
/// @private [doxygen private]
~IValue() {
- if (is_intrusive_ptr) {
- c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
- }
+ destroy();
}
- IValue& operator=(IValue&& rhs) & noexcept {
- IValue(std::move(rhs)).swap(*this); // this also sets rhs to None
+
+ C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
+ if (&rhs == this) {
+ return *this;
+ }
+
+ destroy();
+ moveFrom(std::move(rhs));
return *this;
}
+
IValue& operator=(IValue const& rhs) & {
IValue(rhs).swap(*this);
return *this;
}
+
void dump() const;
/**
@@ -260,6 +273,13 @@ struct TORCH_API IValue final {
return false;
}
+ // Tensors should be compared based on internal storage
+ if (this->isTensor()) {
+ const auto& thisTensor = this->toTensor();
+ const auto& rhsTensor = rhs.toTensor();
+ return thisTensor.is_alias_of(rhsTensor);
+ }
+
if (!this->is_intrusive_ptr) {
// Primitive types don't alias anything
return false;
@@ -267,29 +287,49 @@ struct TORCH_API IValue final {
AT_ASSERT(rhs.is_intrusive_ptr);
- // Tensors should be compared based on internal storage
- if (this->isTensor()) {
- const auto thisTensor = this->toTensor();
- const auto rhsTensor = rhs.toTensor();
- return thisTensor.is_alias_of(rhsTensor);
- }
-
// Other types can be compared by their ptr value
- return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
+ return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
}
/// @private [doxygen private]
size_t use_count() const noexcept {
+ if (isTensor()) {
+ return payload.as_tensor.use_count();
+ }
+
if (!is_intrusive_ptr) {
return 1;
}
- return c10::raw::intrusive_ptr::use_count(payload.as_intrusive_ptr);
+ if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
+ return 0;
+ }
+ return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr);
}
/// @private [doxygen private]
void swap(IValue& rhs) noexcept {
- std::swap(payload, rhs.payload);
+ if (isTensor() && rhs.isTensor()) {
+ std::swap(payload.as_tensor, rhs.payload.as_tensor);
+ } else if (isTensor()) {
+ at::Tensor t = std::move(payload.as_tensor);
+ // As far as I can tell, omitting the usual explicit destructor call
+ // is not UB in and of itself, and it's a slight perf win. The
+ // destructor is a no-op, because the moved-from Tensor is
+ // effectively an intrusive_ptr in the null state, so we don't need
+ // the behavior for correctness reasons either. Leaving this
+ // explanatory comment, including commented-out destructor call, to
+ // make this abundantly clear.
+ //
+ // payload.as_tensor.~Tensor();
+ payload.u = rhs.payload.u;
+ new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
+ } else if (rhs.isTensor()) {
+ rhs.swap(*this);
+ return;
+ } else {
+ std::swap(payload.u, rhs.payload.u);
+ }
std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
std::swap(tag, rhs.tag);
}
@@ -298,21 +338,17 @@ struct TORCH_API IValue final {
// While some of these accessors could be generated through templates,
// we prefer to write them manually for clarity
- IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) {
- // Note: the undefined tensor is not refcounted, so while it
- // is tagged as a tensor, is_intrusive_ptr is set to false.
- // This is not an optional optimization: our incref call
- // *will not* do the right thing when called on an
- // undefined tensor.
- payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl();
+ IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) {
+ new (&payload.as_tensor) at::Tensor(std::move(t));
}
bool isTensor() const {
return Tag::Tensor == tag;
}
at::Tensor toTensor() &&;
- at::Tensor toTensor() const&;
+ at::Tensor& toTensor() &;
+ const at::Tensor& toTensor() const&;
at::TensorImpl* unsafeToTensorImpl() const {
- return static_cast(payload.as_intrusive_ptr);
+ return payload.as_tensor.unsafeGetTensorImpl();
}
IValue(at::Storage s) : tag(Tag::Storage), is_intrusive_ptr(static_cast(s)) {
@@ -321,7 +357,7 @@ struct TORCH_API IValue final {
// This is not an optional optimization: our incref call
// *will not* do the right thing when called on an
// undefined tensor.
- payload.as_intrusive_ptr = s.unsafeReleaseStorageImpl();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(s.unsafeReleaseStorageImpl());
}
bool isStorage() const {
return Tag::Storage == tag;
@@ -341,7 +377,7 @@ struct TORCH_API IValue final {
: tag(Tag::Blob), is_intrusive_ptr(true) {
// TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
// and store it as a Tensor instead.
- payload.as_intrusive_ptr = blob.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
}
/// @private [doxygen private]
@@ -397,14 +433,14 @@ struct TORCH_API IValue final {
// Double
IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) {
- payload.as_double = d;
+ payload.u.as_double = d;
}
bool isDouble() const {
return Tag::Double == tag;
}
double toDouble() const {
AT_ASSERT(isDouble());
- return payload.as_double;
+ return payload.u.as_double;
}
// Future
@@ -433,7 +469,7 @@ struct TORCH_API IValue final {
// Int
IValue(int64_t i) : tag(Tag::Int), is_intrusive_ptr(false) {
- payload.as_int = i;
+ payload.u.as_int = i;
}
// allow you to pass literals (3, 4) without ambiguity
@@ -445,7 +481,7 @@ struct TORCH_API IValue final {
int64_t toInt() const {
AT_ASSERT(isInt());
- return payload.as_int;
+ return payload.u.as_int;
}
// Bool
@@ -454,9 +490,9 @@ struct TORCH_API IValue final {
// Initializing entire payload stops valgrind's from reporting
// "jump or move depends on uninitialised value" in IValue copy constructor
// See https://github.com/pytorch/pytorch/issues/37117
- payload.as_int = b;
+ payload.u.as_int = b;
#else
- payload.as_bool = b;
+ payload.u.as_bool = b;
#endif
}
bool isBool() const {
@@ -464,7 +500,7 @@ struct TORCH_API IValue final {
}
bool toBool() const {
AT_ASSERT(isBool());
- return payload.as_bool;
+ return payload.u.as_bool;
}
// IntList
@@ -580,7 +616,7 @@ struct TORCH_API IValue final {
c10::intrusive_ptr toEnumHolder() const&;
// None
- IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
+ IValue() : tag(Tag::None), is_intrusive_ptr(false) {}
bool isNone() const {
return Tag::None == tag;
}
@@ -616,21 +652,21 @@ struct TORCH_API IValue final {
// Device
IValue(c10::Device d) : tag(Tag::Device), is_intrusive_ptr(false) {
- payload.as_device.type = d.type();
- payload.as_device.index = d.index();
+ payload.u.as_device.type = d.type();
+ payload.u.as_device.index = d.index();
}
bool isDevice() const {
return Tag::Device == tag;
}
c10::Device toDevice() const {
AT_ASSERT(isDevice());
- return c10::Device(payload.as_device.type, payload.as_device.index);
+ return c10::Device(payload.u.as_device.type, payload.u.as_device.index);
}
//Stream
IValue(c10::Stream stream)
: tag(Tag::Stream), is_intrusive_ptr(false) {
- payload.as_int = stream.pack();
+ payload.u.as_int = stream.pack();
}
c10::Stream toStream() &&;
c10::Stream toStream() const &;
@@ -659,7 +695,7 @@ struct TORCH_API IValue final {
// QScheme
IValue(at::QScheme qscheme) : tag(Tag::Int), is_intrusive_ptr(false) {
- payload.as_int = static_cast(qscheme);
+ payload.u.as_int = static_cast(qscheme);
}
at::QScheme toQScheme() const {
@@ -680,7 +716,7 @@ struct TORCH_API IValue final {
// This is not an optional optimization: our incref call
// *will not* do the right thing when called on an
// undefined generator.
- payload.as_intrusive_ptr = g.unsafeReleaseGeneratorImpl();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl());
}
bool isGenerator() const {
return Tag::Generator == tag;
@@ -749,14 +785,19 @@ struct TORCH_API IValue final {
const IValue& v);
bool isPtrType() const {
- return is_intrusive_ptr;
+ return (isTensor() && payload.as_tensor.defined()) || is_intrusive_ptr;
}
/// @private [doxygen private]
const void* internalToPointer() const {
TORCH_INTERNAL_ASSERT(
isPtrType(), "Can only call internalToPointer() for pointer types");
- return payload.as_intrusive_ptr;
+ if (isTensor()) {
+ return payload.as_tensor.unsafeGetTensorImpl();
+ } else {
+ return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()
+ ? payload.u.as_intrusive_ptr : nullptr;
+ }
}
TypePtr type() const;
@@ -770,7 +811,7 @@ struct TORCH_API IValue final {
}
// If it is not a Tensor, then two mutable IValues alias each other only
// if they are the same pointer.
- return val.payload.as_int;
+ return val.payload.u.as_int;
}
};
@@ -800,6 +841,10 @@ struct TORCH_API IValue final {
IValue deepcopy(HashAliasedIValueMap& memo) const;
private:
+ static c10::intrusive_ptr_target* null_to_undefined_tensor(c10::intrusive_ptr_target* p) {
+ return p ? p : static_cast(c10::UndefinedTensorImpl::singleton());
+ }
+
static bool ptrEqual(const IValue& lhs, const IValue& rhs);
// NOTE: IValue tags are intentionally private. In the future we may encode
// this value different (e.g. using NaN boxing), and this would make it more
@@ -822,24 +867,77 @@ struct TORCH_API IValue final {
class NullType = c10::detail::intrusive_target_default_null_type>
c10::intrusive_ptr toIntrusivePtr() const;
- void clearToNone() {
- payload.as_int = 0;
+ void destroy() {
+ // We carefully construct this call to both 1) avoid UB by using
+ // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable
+ // the compiler to generate the same code for each case. It is
+ // surprisingly difficult to get this right.
+ if (isTensor() || is_intrusive_ptr) {
+ c10::intrusive_ptr_target* p = isTensor() ? payload.as_tensor.unsafeGetTensorImpl() : payload.u.as_intrusive_ptr;
+ c10::intrusive_ptr::reclaim(p);
+ // No need to make this destructor call!
+ // payload.as_tensor.~Tensor();
+ }
+ }
+
+ C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
+ if (rhs.isTensor()) {
+ new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
+ // As far as I can tell, omitting the usual explicit destructor call
+ // is not UB in and of itself, and it's a slight perf win. The
+ // destructor is a no-op, because the moved-from Tensor is
+ // effectively an intrusive_ptr in the null state, so we don't need
+ // the behavior for correctness reasons either. Leaving this
+ // explanatory comment, including commented-out destructor call, to
+ // make this abundantly clear.
+ //
+ // rhs.payload.as_tensor.~Tensor();
+ } else {
+ payload.u = rhs.payload.u;
+ }
+ tag = rhs.tag;
+ is_intrusive_ptr = rhs.is_intrusive_ptr;
+ rhs.clearToNone();
+ }
+
+ void clearToNone() noexcept {
+ payload.u.as_int = 0;
tag = Tag::None;
is_intrusive_ptr = false;
}
union Payload {
- int64_t as_int;
- double as_double;
- bool as_bool;
- c10::intrusive_ptr_target* as_intrusive_ptr;
- struct {
- DeviceType type;
- DeviceIndex index;
- } as_device;
+ // We use a nested union here so that we can make the copy easy
+ // and efficient in the non-tensor (i.e., trivially copyable)
+ // case. Specifically, we do not have to do a switch-on-tag to
+ // figure out which union member to assign; we can just use
+ // TriviallyCopyablePayload::operator=.
+ union TriviallyCopyablePayload {
+ TriviallyCopyablePayload() : as_int(0) {}
+ int64_t as_int;
+ double as_double;
+ bool as_bool;
+ // Invariant: never nullptr; null state is represented as
+ // c10::UndefinedTensorImpl::singleton() for consistency of
+ // representation with Tensor.
+ c10::intrusive_ptr_target* as_intrusive_ptr;
+ struct {
+ DeviceType type;
+ DeviceIndex index;
+ } as_device;
+ } u;
+ at::Tensor as_tensor;
+ Payload() : u() {}
+ ~Payload() {}
};
- IValue(Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {}
+ IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) {
+ if (isTensor()) {
+ new (&payload.as_tensor) at::Tensor(p.as_tensor);
+ } else {
+ payload.u = p.u;
+ }
+ }
Payload payload;
Tag tag;
@@ -848,29 +946,36 @@ struct TORCH_API IValue final {
};
struct TORCH_API WeakIValue final {
- WeakIValue() : payload{0}, tag(IValue::Tag::None), is_intrusive_ptr(false) {}
+ WeakIValue() : tag(IValue::Tag::None), is_intrusive_ptr(false) {}
WeakIValue(const WeakIValue& rhs)
: payload(rhs.payload),
tag(rhs.tag),
is_intrusive_ptr(rhs.is_intrusive_ptr) {
- if (is_intrusive_ptr) {
+ if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
}
}
WeakIValue(const IValue& rhs)
- : payload(rhs.payload),
- tag(rhs.tag),
+ : tag(rhs.tag),
is_intrusive_ptr(rhs.is_intrusive_ptr) {
+ if (rhs.isTensor()) {
+ payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
+ is_intrusive_ptr = true;
+ } else {
+ payload = rhs.payload.u;
+ }
if (is_intrusive_ptr) {
- c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+ if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+ c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+ }
}
}
WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() {
swap(rhs);
}
~WeakIValue() {
- if (is_intrusive_ptr) {
+ if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr);
}
}
@@ -895,17 +1000,33 @@ struct TORCH_API WeakIValue final {
IValue lock() const {
if (!is_intrusive_ptr) {
- return IValue(payload, tag, false);
+ IValue::Payload newPayload;
+ newPayload.u = payload;
+ return IValue(newPayload, tag, false);
}
- auto temp = c10::weak_intrusive_ptr::reclaim(
- payload.as_intrusive_ptr);
- IValue::Payload pl;
- pl.as_intrusive_ptr = temp.lock().release();
- temp.release();
- if (!pl.as_intrusive_ptr) {
- return IValue();
+ if (IValue::Tag::Tensor == tag) {
+ auto temp = c10::weak_intrusive_ptr::reclaim(
+ static_cast(payload.as_intrusive_ptr));
+ c10::intrusive_ptr ip(temp.lock());
+ temp.release();
+ if (!ip) {
+ return IValue();
+ } else {
+ return IValue(at::Tensor(std::move(ip)));
+ }
} else {
- return IValue(pl, tag, true);
+ auto temp = c10::weak_intrusive_ptr::reclaim(
+ payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+ ? nullptr
+ : payload.as_intrusive_ptr);
+ IValue::Payload pl;
+ pl.u.as_intrusive_ptr = temp.lock().release();
+ temp.release();
+ if (!pl.u.as_intrusive_ptr) {
+ return IValue();
+ } else {
+ return IValue(pl, tag, true);
+ }
}
}
@@ -913,7 +1034,7 @@ struct TORCH_API WeakIValue final {
if (!is_intrusive_ptr) {
return 1;
}
- auto temp = c10::weak_intrusive_ptr::reclaim(
+ auto temp = c10::weak_intrusive_ptr::reclaim(
payload.as_intrusive_ptr);
size_t result = temp.use_count();
temp.release();
@@ -924,7 +1045,7 @@ struct TORCH_API WeakIValue final {
if (!is_intrusive_ptr) {
return 1;
}
- auto temp = c10::weak_intrusive_ptr::reclaim(
+ auto temp = c10::weak_intrusive_ptr::reclaim(
payload.as_intrusive_ptr);
size_t result = temp.weak_use_count();
temp.release();
@@ -935,7 +1056,8 @@ struct TORCH_API WeakIValue final {
}
private:
- IValue::Payload payload;
+ using Payload = IValue::Payload::TriviallyCopyablePayload;
+ Payload payload;
IValue::Tag tag;
bool is_intrusive_ptr;
};
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index 89c8e669c138..b96f4b834989 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -48,14 +48,18 @@ struct tagged_capsule {
template
c10::intrusive_ptr IValue::moveToIntrusivePtr() {
auto t = c10::intrusive_ptr::reclaim(
- static_cast(payload.as_intrusive_ptr));
+ payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+ ? NullType::singleton()
+ : static_cast(payload.u.as_intrusive_ptr));
clearToNone();
return t;
}
template
c10::intrusive_ptr IValue::toIntrusivePtr() const {
auto r = c10::intrusive_ptr::reclaim(
- static_cast(payload.as_intrusive_ptr));
+ payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+ ? NullType::singleton()
+ : static_cast(payload.u.as_intrusive_ptr));
auto p = r;
r.release();
return p;
@@ -131,12 +135,26 @@ inline c10::intrusive_ptr IValue::toEnumHolder() const& {
}
inline at::Tensor IValue::toTensor() && {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
- return at::Tensor(
- moveToIntrusivePtr());
+ auto result = std::move(payload.as_tensor);
+ // As far as I can tell, omitting the usual explicit destructor call
+ // is not UB in and of itself, and it's a slight perf win. The
+ // destructor is a no-op, because the moved-from Tensor is
+ // effectively an intrusive_ptr in the null state, so we don't need
+ // the behavior for correctness reasons either. Leaving this
+ // explanatory comment, including commented-out destructor call, to
+ // make this abundantly clear.
+ //
+ // payload.as_tensor.~Tensor();
+ clearToNone();
+ return result;
}
-inline at::Tensor IValue::toTensor() const& {
+inline at::Tensor& IValue::toTensor() & {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
- return at::Tensor(toIntrusivePtr());
+ return payload.as_tensor;
+}
+inline const at::Tensor& IValue::toTensor() const& {
+ AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
+ return payload.as_tensor;
}
inline c10::Storage IValue::toStorage() && {
AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
@@ -148,10 +166,10 @@ inline c10::Storage IValue::toStorage() const& {
return c10::Storage(toIntrusivePtr());
}
inline c10::Stream IValue::toStream() && {
- return c10::Stream::unpack(payload.as_int);
+ return c10::Stream::unpack(payload.u.as_int);
}
inline c10::Stream IValue::toStream() const& {
- return c10::Stream::unpack(payload.as_int);
+ return c10::Stream::unpack(payload.u.as_int);
}
inline c10::intrusive_ptr IValue::toBlob() && {
AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
@@ -713,7 +731,8 @@ using _guarded_unsigned_long = std::conditional_t<
inline const ivalue::Object& IValue::toObjectRef() const {
AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
- return *static_cast(payload.as_intrusive_ptr);
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
+ return *static_cast(payload.u.as_intrusive_ptr);
}
// note: when adding a DEFINE_TO case here you should also add a
@@ -729,6 +748,7 @@ inline const ivalue::Object& IValue::toObjectRef() const {
inline type IValue::to() const& { \
return this->method_name(); \
}
+
DEFINE_TO(at::Tensor, toTensor)
DEFINE_TO(at::Storage, toStorage)
DEFINE_TO(c10::Stream, toStream)
@@ -980,8 +1000,11 @@ inline c10::List IValue::toIntList() const& {
}
inline std::vector IValue::toIntVector() const {
AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toIntVector on null intrusive_ptr IValue");
return createVectorFromList(
- static_cast(payload.as_intrusive_ptr));
+ static_cast(payload.u.as_intrusive_ptr));
}
inline c10::List IValue::toDoubleList() && {
AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
@@ -993,8 +1016,11 @@ inline c10::List IValue::toDoubleList() const& {
}
inline std::vector IValue::toDoubleVector() const {
AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toDoubleVector on null intrusive_ptr IValue");
return createVectorFromList(
- static_cast(payload.as_intrusive_ptr));
+ static_cast(payload.u.as_intrusive_ptr));
}
inline c10::List IValue::toBoolList() && {
AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
@@ -1014,8 +1040,11 @@ inline c10::List IValue::toTensorList() const& {
}
inline std::vector IValue::toTensorVector() const {
AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toTensorVector on null intrusive_ptr IValue");
return createVectorFromList(
- static_cast(payload.as_intrusive_ptr));
+ static_cast(payload.u.as_intrusive_ptr));
}
inline c10::List IValue::toList() && {
AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
@@ -1027,7 +1056,10 @@ inline c10::List IValue::toList() const& {
}
inline c10::ArrayRef IValue::toListRef() const {
AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
- return static_cast(payload.as_intrusive_ptr)
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toListRef on null intrusive_ptr IValue");
+ return static_cast(payload.u.as_intrusive_ptr)
->list;
}
inline c10::Dict IValue::toGenericDict() && {
@@ -1049,7 +1081,7 @@ inline c10::intrusive_ptr IValue::toTuple() const& {
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::Tuple), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
template <
typename... Args,
@@ -1065,14 +1097,14 @@ inline IValue::IValue(const std::tuple& t)
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::String), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(std::string v)
: IValue(ivalue::ConstantString::create(std::move(v))) {}
inline IValue::IValue(c10::impl::GenericList v)
: tag(Tag::GenericList), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.impl_.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
}
template >
@@ -1104,7 +1136,7 @@ inline IValue::IValue(std::array v) : IValue(c10::List()) {
inline IValue::IValue(c10::impl::GenericDict v)
: tag(Tag::GenericDict), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.impl_.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
}
template
inline IValue::IValue(c10::Dict v)
@@ -1131,17 +1163,17 @@ inline IValue::IValue(c10::nullopt_t) : IValue() {}
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::Object), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::PyObject), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::Enum), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue IValue::make_capsule(
@@ -1149,7 +1181,7 @@ inline IValue IValue::make_capsule(
IValue iv;
iv.tag = Tag::Capsule;
iv.is_intrusive_ptr = true;
- iv.payload.as_intrusive_ptr = blob.release();
+ iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
return iv;
}
@@ -1170,30 +1202,33 @@ IValue::IValue(c10::intrusive_ptr custom_class) {
auto ivalue_obj = c10::ivalue::Object::create(
c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
- payload.as_intrusive_ptr = ivalue_obj.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
tag = Tag::Object;
is_intrusive_ptr = true;
}
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::Future), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::RRef), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline IValue::IValue(c10::intrusive_ptr v)
: tag(Tag::Quantizer), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+ payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}
inline const std::string& IValue::toStringRef() const {
AT_ASSERT(isString(), "Expected String but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toStringRef on null intrusive_ptr IValue");
return static_cast(
- payload.as_intrusive_ptr)
+ payload.u.as_intrusive_ptr)
->string();
}
inline c10::optional> IValue::
@@ -1202,8 +1237,11 @@ inline c10::optional> IValue::
return c10::nullopt;
}
AT_ASSERT(isString(), "Expected optional but got ", tagKind());
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+ "called toOptionalStringRef on null intrusive_ptr IValue");
return std::reference_wrapper(
- static_cast(payload.as_intrusive_ptr)
+ static_cast(payload.u.as_intrusive_ptr)
->string());
}
@@ -1241,15 +1279,13 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const {
// for bool type, do equality check
return this->toBool() == rhs.toBool();
} else if (this->isTensor() && rhs.isTensor()) {
- // for tensor type, just check the as_intrusive_ptr since is_intrusive_ptr
- // is false for undefined tensor
- return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
+ return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
} else if (this->isTensor() && rhs.isNone()) {
// special case: undefined tensor and None are the same identity
- return !this->is_intrusive_ptr;
+ return !this->payload.as_tensor.defined();
} else if (this->isNone() && rhs.isTensor()) {
// special case: undefined tensor and None are the same identity
- return !rhs.is_intrusive_ptr;
+ return !rhs.payload.as_tensor.defined();
} else if (this->isInt() && rhs.isInt()) {
return this->toInt() == rhs.toInt();
} else if (this->isDouble() && rhs.isDouble()) {
@@ -1260,7 +1296,7 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const {
// for objects holding in IValue, do shallow compare on pointer address to
// testify the identity
return this->is_intrusive_ptr && rhs.is_intrusive_ptr &&
- this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
+ this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
}
}
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index a3ae813616e0..7d3890f582b8 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -2370,19 +2370,19 @@ struct TORCH_API AnyClassType : public Type {
inline bool IValue::isDoubleList() const {
// note: avoids calling type() to avoid extra referencing counting for the returned type.
- return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == FloatType::Kind;
+ return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == FloatType::Kind;
}
inline bool IValue::isTensorList() const {
- return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == TensorType::Kind;
+ return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == TensorType::Kind;
}
inline bool IValue::isIntList() const {
- return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == IntType::Kind;
+ return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == IntType::Kind;
}
inline bool IValue::isBoolList() const {
- return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == BoolType::Kind;
+ return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == BoolType::Kind;
}
template<>
diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h
index 37da9ad7ef8d..e5a6d48340cf 100644
--- a/aten/src/ATen/core/jit_type_base.h
+++ b/aten/src/ATen/core/jit_type_base.h
@@ -152,6 +152,20 @@ struct TORCH_API Type : std::enable_shared_from_this {
return nullptr;
}
template
+ T* castRaw() {
+ if (T::Kind == kind()) {
+ return static_cast(this);
+ }
+ return nullptr;
+ }
+ template
+ const T* castRaw() const {
+ if (T::Kind == kind()) {
+ return static_cast(this);
+ }
+ return nullptr;
+ }
+ template
std::shared_ptr expect() {
auto r = cast();
AT_ASSERT(r);
@@ -163,6 +177,18 @@ struct TORCH_API Type : std::enable_shared_from_this {
AT_ASSERT(r);
return r;
}
+ template
+ T& expectRef() {
+ auto* r = castRaw();
+ AT_ASSERT(r);
+ return *r;
+ }
+ template
+ const T& expectRef() const {
+ auto* r = castRaw();
+ AT_ASSERT(r);
+ return *r;
+ }
virtual ~Type() = default;
virtual bool hasFreeVariables() const {
return false;
diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp
index 6259578fdac8..56afe8ca7fb5 100644
--- a/aten/src/ATen/core/op_registration/op_registration_test.cpp
+++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -1909,7 +1909,7 @@ TEST(NewOperatorRegistrationTest, CppFunction) {
m.def("fn3", [](const Tensor& x) { return x; });
// These require explicit schema
m.def("fn4(Tensor x) -> Tensor", CppFunction::makeFallthrough());
- m.def("fn5(Tensor x) -> Tensor", CppFunction::makeUnboxedOnly(dummy_fn));
+ m.def("fn5(Tensor x) -> Tensor", CppFunction::makeFromUnboxedFunction(dummy_fn));
m.def("fn6(Tensor x) -> Tensor", CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
}
diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
index 8a5e4f48e0c0..f0572bb6d809 100644
--- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
+++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
@@ -130,6 +130,67 @@ uint64_t CUDAGeneratorImpl::seed() {
return random;
}
+/**
+ * Gets the current internal state of CUDAGeneratorImpl. The internal
+ * state is returned as a CPU byte tensor.
+ */
+c10::intrusive_ptr CUDAGeneratorImpl::get_state() const {
+ // The RNG state comprises the seed, and an offset used for Philox.
+ // The following line is just here for BC reason. sizeof curandStateMtgp32 is 4120.
+ // It used to be static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
+ // MAX_NUM_BLOCKS was 200 and sizeof(curandStateMtgp32) is 4120. Hardcoding these numbers here
+ // because this is just host side code and we don't want to worry about linking with cuda
+ static const size_t states_size = 200 * sizeof(4120);
+ static const size_t seed_size = sizeof(uint64_t);
+ static const size_t offset_size = sizeof(int64_t);
+ static const size_t total_size = states_size + seed_size + offset_size;
+
+ auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+ auto rng_state = state_tensor.data_ptr();
+ // since curandStateMTGP is not used anymore, fill gen_states of THCGenerator with deterministic garbage value of -1
+ // gen_states in THCGenerator struct was an array of curandStateMtgp32s.
+ memset(rng_state, -1, states_size);
+ auto current_seed = this->current_seed();
+ auto offset = static_cast(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic
+ memcpy(rng_state + states_size, ¤t_seed, seed_size);
+ memcpy(rng_state + states_size + seed_size, &offset, offset_size);
+
+ return state_tensor.getIntrusivePtr();
+}
+
+/**
+ * Sets the internal state of CUDAGeneratorImpl. The new internal state
+ * must be a strided CPU byte tensor and have appropriate size. See
+ * comments of CUDAGeneratorImpl::state for information about the layout
+ * and size of the internal state.
+ */
+void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
+ static const size_t states_size = 200 * sizeof(4120); // this line is just here for BC reason
+ static const size_t seed_size = sizeof(uint64_t);
+ static const size_t offset_size = sizeof(int64_t);
+ static const size_t total_size = states_size + seed_size + offset_size;
+
+ detail::check_rng_state(new_state);
+
+ bool no_philox_seed = false;
+ auto new_state_size = new_state.numel();
+ if (new_state_size == total_size - offset_size) {
+ no_philox_seed = true;
+ } else {
+ TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
+ }
+
+ uint64_t input_seed;
+ auto new_rng_state = new_state.data();
+ memcpy(&input_seed, new_rng_state + states_size, seed_size);
+ this->set_current_seed(input_seed);
+ int64_t philox_offset = 0;
+ if (!no_philox_seed) {
+ memcpy(&philox_offset, new_rng_state + states_size + seed_size, offset_size);
+ }
+ this->set_philox_offset_per_thread(static_cast(philox_offset));
+}
+
/**
* Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10
*
@@ -143,7 +204,7 @@ void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
/**
* Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
*/
-uint64_t CUDAGeneratorImpl::philox_offset_per_thread() {
+uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::philox_offset_per_thread");
return philox_offset_per_thread_;
}
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
index f38860e8ef13..b75ef8219b1c 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -369,6 +369,11 @@ int CUDAHooks::getNumGPUs() const {
return at::cuda::device_count();
}
+void CUDAHooks::deviceSynchronize(int64_t device_index) const {
+ at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
+ c10::cuda::device_synchronize();
+}
+
// Sigh, the registry doesn't support namespaces :(
using at::CUDAHooksRegistry;
using at::RegistererCUDAHooksRegistry;
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h
index dff8913b153f..abef2e7ff835 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.h
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -38,6 +38,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
void cuFFTClearPlanCache(int64_t device_index) const override;
int getNumGPUs() const override;
+ void deviceSynchronize(int64_t device_index) const override;
};
}}} // at::cuda::detail
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h
index af4eb6fd0739..afe88761d88f 100644
--- a/aten/src/ATen/detail/CUDAHooksInterface.h
+++ b/aten/src/ATen/detail/CUDAHooksInterface.h
@@ -181,6 +181,10 @@ struct TORCH_API CUDAHooksInterface {
virtual int getNumGPUs() const {
return 0;
}
+
+ virtual void deviceSynchronize(int64_t device_index) const {
+ TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
+ }
};
// NB: dummy argument to suppress "ISO C++11 requires at least one argument
diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp
index ef0c2e2509c1..413ea32acdef 100644
--- a/aten/src/ATen/native/Distributions.cpp
+++ b/aten/src/ATen/native/Distributions.cpp
@@ -118,7 +118,7 @@ DEFINE_DISPATCH(bernoulli_tensor_stub);
DEFINE_DISPATCH(bernoulli_scalar_stub);
DEFINE_DISPATCH(cauchy_stub);
DEFINE_DISPATCH(exponential_stub);
-DEFINE_DISPATCH(multinomial_stub);
+DEFINE_DISPATCH(multinomial_with_replacement_stub);
DEFINE_DISPATCH(geometric_stub);
DEFINE_DISPATCH(log_normal_stub);
DEFINE_DISPATCH(uniform_stub);
@@ -497,8 +497,10 @@ Tensor& multinomial_out(
// Reference:
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
// Half is not supported on CPU.
- if (!with_replacement &&
- !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half)) {
+ TORCH_CHECK(
+ !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half),
+ "multinomial is not implemented for half on CPU");
+ if (!with_replacement) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
@@ -537,13 +539,8 @@ Tensor& multinomial_out(
return result;
}
- multinomial_stub(
- result.device().type(),
- result,
- self,
- n_sample,
- with_replacement,
- gen);
+ multinomial_with_replacement_stub(
+ result.device().type(), result, self, n_sample, gen);
return result;
}
diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h
index 071460b090cd..8b5d65a8a60f 100644
--- a/aten/src/ATen/native/Pool.h
+++ b/aten/src/ATen/native/Pool.h
@@ -72,7 +72,7 @@ pool2d_shape_check(
TORCH_CHECK(input.numel() > 0 && (ndim == 3 || ndim == 4),
"non-empty 3D or 4D input tensor expected but got ndim: ", ndim);
TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
- "pad should be smaller than half of kernel size, but got ",
+ "pad should be smaller than or equal to half of kernel size, but got ",
"padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
@@ -172,7 +172,7 @@ pool3d_shape_check(
}
TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
- "pad should be smaller than half of kernel size, but got "
+ "pad should be smaller than or equal to half of kernel size, but got "
"kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index c8eb3cc99a01..289d1128d2f9 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -102,9 +102,12 @@ Tensor resize_fft_input(Tensor x, IntArrayRef dims, IntArrayRef sizes) {
}
// Complex to real FFT
-Tensor fft_c2r(Tensor input, c10::optional n_opt,
+Tensor fft_c2r(c10::string_view function_name,
+ Tensor out, Tensor input, c10::optional n_opt,
int64_t unwrapped_dim, c10::optional norm_str,
bool forward) {
+ TORCH_CHECK(!out.defined() || out.is_floating_point(), function_name,
+ " expects a floating point output tensor, but got ", out.scalar_type());
input = promote_tensor_fft(input, /*require_complex=*/true);
const auto input_dim = input.dim();
const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
@@ -118,14 +121,22 @@ Tensor fft_c2r(Tensor input, c10::optional n_opt,
// FIXME: _fft does not support complex_output=false with inverse=false
input = at::conj(input);
}
- return at::_fft_c2r(input, dim, static_cast(norm), n);
+ if (out.defined()) {
+ return at::_fft_c2r_out(out, input, dim, static_cast(norm), n);
+ } else {
+ return at::_fft_c2r(input, dim, static_cast(norm), n);
+ }
}
// Real to complex FFT
-Tensor fft_r2c(Tensor input, c10::optional n_opt,
+Tensor fft_r2c(c10::string_view function_name,
+ Tensor out, Tensor input, c10::optional n_opt,
int64_t unwrapped_dim, c10::optional norm_str,
bool forward, bool onesided) {
- TORCH_CHECK(!input.is_complex(), "Expected a real input tensor to FFT");
+ TORCH_CHECK(!input.is_complex(), function_name,
+ " expects a real input tensor, but got ", input.scalar_type());
+ TORCH_CHECK(!out.defined() || out.is_complex(), function_name,
+ " expects a complex output tensor, but got ", out.scalar_type());
input = promote_tensor_fft(input);
const auto input_dim = input.dim();
const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim);
@@ -136,19 +147,29 @@ Tensor fft_r2c(Tensor input, c10::optional n_opt,
}
const auto norm = norm_from_string(norm_str, forward);
- auto out = at::_fft_r2c(input, dim, static_cast(norm), onesided);
+
+ Tensor ret;
+ if (out.defined() && forward) {
+ ret = at::_fft_r2c_out(out, input, dim, static_cast(norm), onesided);
+ } else {
+ ret = at::_fft_r2c(input, dim, static_cast