diff --git a/.ci/manywheel/build_cpu.sh b/.ci/manywheel/build_cpu.sh index c3ddba33cd946..ad51810e06a2b 100755 --- a/.ci/manywheel/build_cpu.sh +++ b/.ci/manywheel/build_cpu.sh @@ -75,9 +75,11 @@ if [[ "$ARCH" == "aarch64" ]]; then # ARM system libraries DEPS_LIST+=( "/usr/lib64/libgfortran.so.5" + "/opt/OpenBLAS/lib/libopenblas.so.0" ) DEPS_SONAME+=( "libgfortran.so.5" + "libopenblas.so.0" ) fi diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 8462dd2aa4e55..616dfd88ce812 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -07b6cbde121417a70e4dc871adb6d27030e0ce3f +ee1a1350eb37804b94334768f328144f058f14e9 diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 191c21631f662..803ba72d9ac92 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a +94631807d22c09723dd006f7be5beb649d5f88d0 diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 2b9558197bdcb..2d7ca10433d6a 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -245,6 +245,9 @@ class TORCH_API TensorBase { size_t weak_use_count() const noexcept { return impl_.weak_use_count(); } + bool is_uniquely_owned() const noexcept { + return impl_.is_uniquely_owned(); + } std::string toString() const; diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h index 86e960cc1ab4a..01d10f61da692 100644 --- a/aten/src/ATen/cuda/CUDAContextLight.h +++ b/aten/src/ATen/cuda/CUDAContextLight.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); TORCH_CUDA_CPP_API void clearCublasWorkspaces(); -TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublas_handle_stream_to_workspace(); -TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublaslt_handle_stream_to_workspace(); +struct WorkspaceMapWithMutex { + std::map, at::DataPtr> map; + std::shared_mutex mutex; +}; + +TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace(); +TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace(); TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize(); TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize(); TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace(); diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 6175e69827e2f..9ec3acf4cd29e 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) { // - Comments of @soumith copied from cuDNN handle pool implementation #ifdef NO_CUDNN_DESTROY_HANDLE #else - cublasDestroy(handle); + cublasDestroy(handle); #endif } @@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool, at::DataPtr>& cublas_handle_stream_to_workspace() { - static auto& instance = *new std::map, at::DataPtr>; +WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() { + static auto& instance = *new WorkspaceMapWithMutex; return instance; } -std::map, at::DataPtr>& cublaslt_handle_stream_to_workspace() { - static auto& instance = *new std::map, at::DataPtr>; +WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() { + static auto& instance = *new WorkspaceMapWithMutex; return instance; } void clearCublasWorkspaces() { - 
cublas_handle_stream_to_workspace().clear(); - cublaslt_handle_stream_to_workspace().clear(); + { + auto& workspace = cublas_handle_stream_to_workspace(); + std::unique_lock lock(workspace.mutex); + workspace.map.clear(); + } + { + auto& workspace = cublaslt_handle_stream_to_workspace(); + std::unique_lock lock(workspace.mutex); + workspace.map.clear(); + } } size_t parseChosenWorkspaceSize() { @@ -233,6 +241,38 @@ at::DataPtr getNewCUDABlasLtWorkspace() { return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize()); } +void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) { + cudaStream_t _stream = stream; + auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); + + auto& workspace = cublas_handle_stream_to_workspace(); + + size_t workspace_size = getChosenWorkspaceSize(); + + // Fast path: check if workspace already exists + { + std::shared_lock lock(workspace.mutex); + auto workspace_it = workspace.map.find(key); + if (workspace_it != workspace.map.end()) { + TORCH_CUDABLAS_CHECK(cublasSetWorkspace( + handle, workspace_it->second.get(), workspace_size)); + return; + } + } + + // Slow path: allocate workspace outside the lock + auto new_workspace = getNewWorkspace(); + + // Insert with lock (double-check in case another thread inserted while we + // were allocating) + { + std::unique_lock lock(workspace.mutex); + auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first; + TORCH_CUDABLAS_CHECK( + cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size)); + } +} + void* getCUDABlasLtWorkspace() { #ifndef USE_ROCM static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true; @@ -241,8 +281,10 @@ void* getCUDABlasLtWorkspace() { auto stream = c10::cuda::getCurrentCUDAStream(); cudaStream_t _stream = stream; auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); - auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key); - TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end()); + auto& workspace = at::cuda::cublas_handle_stream_to_workspace(); + std::shared_lock lock(workspace.mutex); + auto workspace_it = workspace.map.find(key); + TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end()); return workspace_it->second.mutable_get(); } #endif @@ -250,11 +292,29 @@ void* getCUDABlasLtWorkspace() { auto stream = c10::cuda::getCurrentCUDAStream(); cudaStream_t _stream = stream; auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); - auto workspace_it = cublaslt_handle_stream_to_workspace().find(key); - if (workspace_it == cublaslt_handle_stream_to_workspace().end()) { - workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()}); + + auto& workspace = cublaslt_handle_stream_to_workspace(); + + // Fast path: check if workspace already exists + { + std::shared_lock lock(workspace.mutex); + auto workspace_it = workspace.map.find(key); + if (workspace_it != workspace.map.end()) { + return workspace_it->second.mutable_get(); + } + } + + // Slow path: allocate workspace outside the lock + auto new_workspace = getNewCUDABlasLtWorkspace(); + + // Insert with lock (double-check in case another thread inserted while we + // were allocating) + { + std::unique_lock lock(workspace.mutex); + auto workspace_it = + workspace.map.try_emplace(key, std::move(new_workspace)).first; + return workspace_it->second.mutable_get(); } - return 
workspace_it->second.mutable_get(); } cublasHandle_t getCurrentCUDABlasHandle() { @@ -298,13 +358,8 @@ cublasHandle_t getCurrentCUDABlasHandle() { // will allocate memory dynamically (even if they're cheap) outside // PyTorch's CUDA caching allocator. It's possible that CCA used up // all the memory and cublas's cudaMallocAsync will return OOM - cudaStream_t _stream = stream; - auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); - auto workspace_it = cublas_handle_stream_to_workspace().find(key); - if (workspace_it == cublas_handle_stream_to_workspace().end()) { - workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()}); - } - TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize())); + setWorkspaceForHandle(handle, stream); + #if !defined(USE_ROCM) // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup // FP32 data type calculations based on the value of the allow_tf32 flag. diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 2754d70cac013..75a4d357a1c0b 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -296,7 +296,7 @@ template bool launchGemmAndBiasCublasLt( // args contains result which is modified cublasCommonArgs& args, - const Tensor& self, + const std::optional& self, const Scalar& alpha, Activation activation = Activation::None ) { @@ -304,12 +304,8 @@ bool launchGemmAndBiasCublasLt( // or when it can be squeezed to 1D. // self_ptr == nullptr implies ignore bias epilogue // and use standard gemm-like API. - const auto* self_ptr = [&]() -> auto { - if (self.dim() == 1 || self.squeeze().dim() == 1) { - return self.const_data_ptr(); - } - return static_cast(nullptr); - }(); + const auto* self_ptr = self.has_value() ? 
self.value().const_data_ptr() : static_cast(nullptr); + const auto tuning_ctx = at::cuda::tunable::getTuningContext(); if (tuning_ctx->IsTunableOpEnabled()) { @@ -392,35 +388,30 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override; #ifdef USE_ROCM // Conditioned on the device index, which is not persistent - disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt; + disable_addmm_cuda_lt = disable_addmm_cuda_lt || isGloballyDisabledAddmmCudaLt(self.device()); #endif // Condition on the input - disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt; - // } + disable_addmm_cuda_lt = disable_addmm_cuda_lt || !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation); at::ScalarType scalar_type = mat1.scalar_type(); bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float; + #ifdef USE_ROCM + disable_addmm_cuda_lt = disable_addmm_cuda_lt || is_float_output_with_half_input; + #endif + + bool use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt; + // for float output with half input cublasLT with bias produces wrong results + use_bias_ptr_lt &= !is_float_output_with_half_input; + // Handle result/self shapes if (!result.is_same(self)) { at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]}); - // We use bias ptr in the Lt path only when bias is 1D - const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt; - const auto self_maybe_expanded = [&]() -> c10::MaybeOwned { - if (!use_bias_ptr_lt) { - // We do expand self even before - // check for beta != 0.0 to make sure that - // test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_* - // runs green. - return expand_size(self, result.sizes(), "addmm"); - } - return c10::MaybeOwned::borrowed(self); - }(); - // We do not copy bias only when we need the bias ptr + // We do not copy bias only when we need the bias ptr if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) { // NOTE: self should broadcast over result - at::native::copy_(result, *self_maybe_expanded); + at::native::copy_(result, *expand_size(self, result.sizes(), "addmm")); } } @@ -468,7 +459,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { - lt_success = launchGemmAndBiasCublasLt(args, self, alpha, activation); + lt_success = launchGemmAndBiasCublasLt(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation); } ); #endif @@ -480,7 +471,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { - lt_success = launchGemmAndBiasCublasLt(args, self, alpha, activation); + lt_success = launchGemmAndBiasCublasLt(args, use_bias_ptr_lt ? 
std::make_optional(self) : std::nullopt, alpha, activation); } ); } // end is_float_output_with_half_input @@ -936,7 +927,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { return _int_mm_out_cuda(self, mat2, result); } -static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional& self_baddbmm = std::nullopt) { +static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, const std::optional& self_baddbmm = std::nullopt) { // ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor"); @@ -960,7 +951,7 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat (out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)), "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"); - if (!is_bmm && self_baddbmm.has_value()) { + if (self_baddbmm.has_value()) { const auto& self = self_baddbmm.value(); TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor"); TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output"); @@ -968,15 +959,12 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat } Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) { - IntArrayRef batch1_sizes = batch1.sizes(); - IntArrayRef batch2_sizes = batch2.sizes(); - - Tensor out = at::empty({batch1_sizes[0], batch1_sizes[1], batch2_sizes[2]}, batch1.options().dtype(out_dtype)); + Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype)); return _bmm_out_dtype_cuda(batch1, batch2, out_dtype, out); } Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) { - baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true); + baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype); Scalar beta(0.0); Scalar alpha(1.0); { @@ -988,14 +976,16 @@ Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at } Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) { - // We need to copy the tensor - Tensor out = self.clone().to(self.options().dtype(out_dtype)); - - return _baddbmm_out_dtype_cuda(out, batch1, batch2, out_dtype, beta, alpha, out); + TORCH_CHECK(self.scalar_type() == out_dtype || self.scalar_type() == batch1.dtype(), + "self dtype must match either out_dtype or batch1 dtype"); + Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype)); + return _baddbmm_out_dtype_cuda(self, batch1, batch2, out_dtype, beta, alpha, out); } Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) { - baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self); + baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, out); + // We need to copy the tensor + out.copy_(self); { NoNamesGuard guard; 
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha); @@ -1030,24 +1020,27 @@ Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::Sca } Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) { - Tensor result = at::empty(self.sizes(), self.options().dtype(out_dtype)); + TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); + Tensor result = at::empty({mat1.size(0), mat2.size(1)}, self.options().dtype(out_dtype)); return _addmm_dtype_out_cuda(self, mat1, mat2, out_dtype, beta, alpha, result); } Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) { - TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type()); - TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type()); +// repeat dimensionality checks for direct calls to `out` overload TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor"); TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); TORCH_CHECK( mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); + TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type()); + TORCH_CHECK(out_dtype == mat1.scalar_type() || + (out_dtype == at::ScalarType::Float && (mat1.scalar_type() == at::ScalarType::Half || mat1.scalar_type() == at::ScalarType::BFloat16)), + "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"); TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor"); - TORCH_CHECK(out_dtype == self.scalar_type() || - (out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)), - "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"); - TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor"); + TORCH_CHECK(out_dtype == self.scalar_type() || self.scalar_type() == mat1.scalar_type(), + "self dtype must match either out_dtype or mat1 dtype"); addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha); diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 5c8b98105bb26..fd406829707a1 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -5,69 +5,11 @@ #include #endif -// ROCm 6.3 is planned to have these functions, but until then here they are. 
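// The CublasHandlePool.cpp hunks earlier in this diff protect the
// handle/stream -> workspace maps with a std::shared_mutex and a read-mostly,
// double-checked pattern: look up under a shared lock, allocate outside any
// lock, then insert under an exclusive lock with try_emplace so the first
// writer wins. A minimal standalone sketch of that pattern follows; the names
// (WorkspaceCache, get_or_create) are illustrative, not PyTorch symbols.
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>

struct WorkspaceCache {
  std::map<void*, std::shared_ptr<int>> map;
  std::shared_mutex mutex;
};

std::shared_ptr<int> get_or_create(WorkspaceCache& cache, void* key) {
  // Fast path: readers share the lock and never allocate.
  {
    std::shared_lock<std::shared_mutex> lock(cache.mutex);
    auto it = cache.map.find(key);
    if (it != cache.map.end()) {
      return it->second;
    }
  }
  // Slow path: allocate outside the lock to keep the critical section short,
  // then insert under an exclusive lock; try_emplace discards the new value
  // if another thread inserted the same key while we were allocating.
  auto fresh = std::make_shared<int>(0);
  std::unique_lock<std::shared_mutex> lock(cache.mutex);
  return cache.map.try_emplace(key, std::move(fresh)).first->second;
}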
#if defined(USE_ROCM) #include #include #include - -__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { -#if (defined(__gfx942__)) && \ - __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) - typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; - static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); - union { - __hip_bfloat162_raw bf162_raw; - vec_short2 vs2; - } u{static_cast<__hip_bfloat162_raw>(value)}; - u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2); - return static_cast<__hip_bfloat162>(u.bf162_raw); -#else - static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw)); - union u_hold { - __hip_bfloat162_raw h2r; - unsigned int u32; - }; - u_hold old_val, new_val; - old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - do { - new_val.h2r = __hadd2(old_val.h2r, value); - } while (!__hip_atomic_compare_exchange_strong( - (unsigned int*)address, &old_val.u32, new_val.u32, - __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); - return old_val.h2r; -#endif -} - -__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) { -#if (defined(__gfx942__)) && \ - __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16) - // The api expects an ext_vector_type of half - typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; - static_assert(sizeof(vec_fp162) == sizeof(__half2_raw)); - union { - __half2_raw h2r; - vec_fp162 fp16; - } u {static_cast<__half2_raw>(value)}; - u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16); - return static_cast<__half2>(u.h2r); -#else - static_assert(sizeof(__half2_raw) == sizeof(unsigned int)); - union u_hold { - __half2_raw h2r; - unsigned int u32; - }; - u_hold old_val, new_val; - old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - do { - new_val.h2r = __hadd2(old_val.h2r, value); - } while (!__hip_atomic_compare_exchange_strong( - (unsigned int*)address, &old_val.u32, new_val.u32, - __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); - return old_val.h2r; -#endif -} -#define ATOMICADD preview_unsafeAtomicAdd +#define ATOMICADD unsafeAtomicAdd #define NATIVE_ZERO_BF16 __float2bfloat16(0.0f) #else #define ATOMICADD atomicAdd diff --git a/aten/src/ATen/native/cuda/LogAddExpKernel.cu b/aten/src/ATen/native/cuda/LogAddExpKernel.cu index 7b8b5b5bb2032..910d3c1cddc93 100644 --- a/aten/src/ATen/native/cuda/LogAddExpKernel.cu +++ b/aten/src/ATen/native/cuda/LogAddExpKernel.cu @@ -2,18 +2,250 @@ #include #include #include +#include +#include +#include #include #include #include #include +#include + +#include +#include // NOTE: CUDA on Windows requires that the enclosing function // of a __device__ lambda not have internal linkage. namespace at::native { +// custom min and max to be used in logaddexp for complex arguments +template +__host__ __device__ c10::complex _logaddexp_minmax(const c10::complex& x, const c10::complex& y) { + scalar_t xr = std::real(x); + scalar_t yr = std::real(y); + if (::isnan(yr) || (::isnan(std::imag(y)))) { + return y; + } else if (::isnan(xr) || (::isnan(std::imag(x)))) { + return x; + } else if (min) { // min + return (xr < yr) ? x : y; + } else { // max + return (xr >= yr) ? 
x : y; + } +} + +template +__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) { + // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp + // Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM + const auto isnan_x = at::_isnan(x); + const auto isnan_y = at::_isnan(y); + scalar_t min = isnan_y ? y : (isnan_x ? x : std::min(x, y)); + scalar_t max = isnan_y ? y : (isnan_x ? x : std::max(x, y)); + if (min != max || ::isfinite(min)) { + // nan will be propagated here + return ::log1p(std::exp(min - max)) + max; + } else { + // special case to correctly handle infinite cases + return x; + } +} + +template +__host__ __device__ c10::complex _fast_build_exp(const c10::complex& x) { + // complex exponential function, but implemented manually to get fast compilation time + // this function only handles the case where the x is finite (not inf nor nan) + const auto xreal = std::real(x); + const auto ximag = std::imag(x); + const auto exp_x_abs = std::exp(xreal); + auto exp_x_real = exp_x_abs * std::cos(ximag); + auto exp_x_imag = exp_x_abs * std::sin(ximag); + return {exp_x_real, exp_x_imag}; +} + +template +__host__ __device__ c10::complex _fast_build_exp_inf(const c10::complex& x) { + // complex exponential function, but implemented manually to get fast compilation time + // this function only handles the case where the real part of x is infinite + const auto ximag = std::imag(x); + constexpr auto exp_x_abs = std::numeric_limits::infinity(); + if (!::isfinite(ximag)) { // add this to make consitent with std::exp(x+yi) + return {exp_x_abs, std::numeric_limits::quiet_NaN()}; + } + const auto sin = std::sin(ximag); + const auto cos = std::cos(ximag); + // special case if the angle is exactly the multiple of pi/2 + auto exp_x_real = (cos == 0) ? (scalar_t)0.0 : exp_x_abs * cos; + auto exp_x_imag = (sin == 0) ? (scalar_t)0.0 : exp_x_abs * sin; + return {exp_x_real, exp_x_imag}; +} + +template +__host__ __device__ c10::complex _log_add_exp_helper(const c10::complex& x, const c10::complex& y) { + c10::complex min = _logaddexp_minmax(x, y); + c10::complex max = _logaddexp_minmax(x, y); + scalar_t min_real = std::real(min); + scalar_t max_real = std::real(max); + + if (::isnan(min_real) || ::isnan(std::imag(min))) { + // handling the "infectious" NaNs + return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; + } + else if ((!::isfinite(min_real)) && (min_real == max_real)) { + if (min_real < 0) { + // handle the -inf case, the imaginary part here does not really matter as the exp(value) + // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined. + // It does not matter if we're taking the exp of this value + return min; + } else { + // handle the +inf case, we don't need the special precision for log1p for small values + // and to avoid producing nan in case of real(max) == real(min) == +inf + const auto exp_min = _fast_build_exp_inf(min); + const auto exp_max = _fast_build_exp_inf(max); + return ::log1p(exp_min + exp_max - 1); // log1p(x - 1) builds faster than log + } + } else { + const auto minmax = min - max; + c10::complex exp_minmax; + if (!::isfinite(minmax.real())) { + exp_minmax = minmax.real() < 0 ? 
c10::complex{0.0, 0.0} : _fast_build_exp_inf(minmax); + } else { + exp_minmax = _fast_build_exp(minmax); + } + return ::log1p(exp_minmax) + max; + } +} + +// Complex logaddexp jiterator string +const auto logaddexp_complex_string = jiterator_stringify( + template + std::complex log1p(const std::complex& z) + { + using complex_t = std::complex; + T x = z.real(); + T y = z.imag(); + T zabs = abs(z); + T theta = atan2(y, x + T(1)); + if (zabs < 0.5) { + T r = x * (T(2) + x) + y * y; + if (r == 0) { // handle underflow + return complex_t(x, theta); + } + return complex_t(T(0.5) * std::log1p(r), theta); + } else { + T z0 = std::hypot(x + 1, y); + return complex_t(log(z0), theta); + } + } + + // separated _logaddexp_minmax into 2 different functions for jiterator_string + template + std::complex logaddexp_min(const std::complex& x, const std::complex& y) { + T xr = x.real(); + T yr = y.real(); + if (isnan(yr) || isnan(y.imag())) { + return y; + } else if (isnan(xr) || isnan(x.imag())) { + return x; + } else { + return (xr < yr) ? x : y; + } + } + + template + std::complex logaddexp_max(const std::complex& x, const std::complex& y) { + T xr = x.real(); + T yr = y.real(); + if (isnan(yr) || isnan(y.imag())) { + return y; + } else if (isnan(xr) || isnan(x.imag())) { + return x; + } else { + return (xr >= yr) ? x : y; + } + } + + template + std::complex fast_build_exp(const std::complex& x) { + const auto xreal = x.real(); + const auto ximag = x.imag(); + const auto exp_x_abs = exp(xreal); + auto exp_x_real = exp_x_abs * cos(ximag); + auto exp_x_imag = exp_x_abs * sin(ximag); + return std::complex(exp_x_real, exp_x_imag); + } + + template + std::complex fast_build_exp_inf(const std::complex& x) { + using complex_t = std::complex; + const auto ximag = x.imag(); + const T exp_x_abs = INFINITY; + if (!isfinite(ximag)) { + return complex_t(exp_x_abs, NAN); + } + const auto sin_val = sin(ximag); + const auto cos_val = cos(ximag); + auto exp_x_real = (cos_val == T(0)) ? T(0) : exp_x_abs * cos_val; + auto exp_x_imag = (sin_val == T(0)) ? T(0) : exp_x_abs * sin_val; + return complex_t(exp_x_real, exp_x_imag); + } + + template + complex_t logaddexp_complex(complex_t x, complex_t y) { + using T = typename complex_t::value_type; + complex_t min_val = logaddexp_min(x, y); + complex_t max_val = logaddexp_max(x, y); + T min_real = min_val.real(); + T max_real = max_val.real(); + + if (isnan(min_real) || isnan(min_val.imag())) { + return complex_t(NAN, NAN); + } + else if ((!isfinite(min_real)) && (min_real == max_real)) { + if (min_real < T(0)) { + return min_val; + } else { + const auto exp_min = fast_build_exp_inf(min_val); + const auto exp_max = fast_build_exp_inf(max_val); + return log1p(exp_min + exp_max - complex_t(1, 0)); + } + } else { + const auto minmax = min_val - max_val; + complex_t exp_minmax; + if (!isfinite(minmax.real())) { + exp_minmax = (minmax.real() < T(0)) ? 
complex_t(0, 0) : fast_build_exp_inf(minmax); + } else { + exp_minmax = fast_build_exp(minmax); + } + return log1p(exp_minmax) + max_val; + } + } +); + +constexpr char logaddexp_complex_name[] = "logaddexp_complex"; void logaddexp_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( + if (at::isComplexType(iter.dtype())) { +#if AT_USE_JITERATOR() + AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() { + jitted_gpu_kernel< + /*name=*/logaddexp_complex_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/2>(iter, logaddexp_complex_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(at::ScalarType::ComplexHalf, iter.dtype(), "logaddexp_cuda", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t { + const auto a = static_cast(a_); + const auto b = static_cast(b_); + return static_cast(_log_add_exp_helper(a, b)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "logaddexp_cuda", [&]() { @@ -29,6 +261,7 @@ void logaddexp_kernel_cuda(TensorIteratorBase& iter) { } }); }); + } } void logaddexp2_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index ac12b812c0670..4ff61f71f2b61 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -1101,6 +1101,19 @@ _scaled_mxfp8_mxfp8( return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); } +void +_check_mxfp4_support() { +#ifndef USE_ROCM + auto dprops = at::cuda::getCurrentDeviceProperties(); + // Only on B200 GPUs + TORCH_CHECK_NOT_IMPLEMENTED( + // B200 = 10.0, B300 = 10.3 + dprops->major == 10, + "MXFP4 scaling only supported in CUDA for B200/B300" + ); +#endif +} + Tensor& _scaled_mxfp4_mxfp4( @@ -1113,6 +1126,7 @@ _scaled_mxfp4_mxfp4( #if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)) TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only"); #else + _check_mxfp4_support(); // Restrictions: // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ", diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index 40afa15b4f700..f350b0137b05e 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -91,26 +91,31 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { #include #endif -template -void computeRepeatIndices(const index_t* repeat_ptr, - const int64_t* cumsum_ptr, - index_t* result_ptr, - int64_t size, - int64_t result_size) { - id repeatBuffer = reinterpret_cast>(repeat_ptr); - id cumsumBuffer = reinterpret_cast>(cumsum_ptr); - id resultBuffer = reinterpret_cast>(result_ptr); - TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer); - +Tensor repeat_interleave_mps(const Tensor& repeat, std::optional output_size) { + TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat"); std::string scalar_type; - if constexpr (std::is_same_v) { + if (repeat.scalar_type() == kInt) { scalar_type = "int32_t"; - } else if constexpr (std::is_same_v) { + } else if 
(repeat.scalar_type() == kLong) { scalar_type = "int64_t"; } else { - TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type"); + TORCH_CHECK(false, "repeats has to be Long or Int tensor"); + } + if (repeat.size(0) == 0) { + return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + Tensor repeat_ = repeat.contiguous(); + Tensor cumsum = repeat.cumsum(0); + int64_t total = 0; + if (output_size.has_value()) { + total = output_size.value(); + } else { + total = cumsum[-1].item(); + TORCH_CHECK((repeat >= 0).all().item(), "repeats can not be negative"); } + auto result = at::empty({total}, repeat.options()); + MPSStream* mpsStream = getCurrentMPSStream(); dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { @@ -121,20 +126,13 @@ void computeRepeatIndices(const index_t* repeat_ptr, getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false); [computeEncoder setComputePipelineState:pipelineState]; - mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size); - mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size); + mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0)); + mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0)); getMPSProfiler().endProfileKernel(pipelineState); } }); -} - -Tensor repeat_interleave_mps(const Tensor& repeat, std::optional output_size) { - Tensor output; - AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() { - output = repeat_interleave_common>(repeat, output_size); - }); - return output; + return result; } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 90371fd8745c8..ed659bddd65cc 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -5,6 +5,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -89,13 +90,21 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& auto clamp_shape = clamp_opt->sizes(); auto input_shape = input_t.sizes(); - TORCH_CHECK(num_clamp_dims <= num_input_dims, - op_name + ": clamp tensor number of dims must not be greater than that of input tensor") + if (num_clamp_dims > num_input_dims) { + auto leading_dims = num_clamp_dims - num_input_dims; + for (int64_t i = 0; i < leading_dims; ++i) { + TORCH_CHECK(clamp_shape[i] == 1, + op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor"); + } + } - for (int i = 0; i < num_clamp_dims; i++) + auto clamp_idx = num_clamp_dims - 1; + auto input_idx = num_input_dims - 1; + auto common_dims = std::min(num_clamp_dims, num_input_dims); + for (int64_t i = 0; i < common_dims; ++i) // One of the indices is allowed to be 1; will be handled by broadcast - TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] || - clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1, + TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 || + input_shape[input_idx - i] == 1, op_name + ": clamp tensor trailing shape must match input tensor") } } @@ -136,9 +145,6 @@ static void clamp_tensor_out_mps(const Tensor& input_t, auto result_type = output_t.scalar_type(); - IntArrayRef new_min_shape; - IntArrayRef new_max_shape; - auto num_min_dims = min_opt->dim(); auto num_max_dims = max_opt->dim(); auto 
num_input_dims = input_t.dim(); @@ -146,24 +152,32 @@ static void clamp_tensor_out_mps(const Tensor& input_t, std::vector new_min_arr(num_input_dims); std::vector new_max_arr(num_input_dims); - if (has_min && num_min_dims < num_input_dims) { - fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes()); - new_min_shape = IntArrayRef(new_min_arr); - } - - if (has_max && num_max_dims < num_input_dims) { - fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes()); - new_max_shape = IntArrayRef(new_max_arr); - } - Tensor min_opt_tensor; Tensor max_opt_tensor; + auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref, + int64_t num_clamp_dims, + std::vector& new_shape_storage) -> Tensor { + IntArrayRef clamp_shape = clamp_tensor_ref->sizes(); + bool requires_view = false; + + if (num_clamp_dims > num_input_dims) { + clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims); + requires_view = true; + } else if (num_clamp_dims < num_input_dims) { + fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape); + clamp_shape = IntArrayRef(new_shape_storage); + requires_view = true; + } + + return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref; + }; + if (has_min) { - min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt; + min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr); } if (has_max) { - max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt; + max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr); } @autoreleasepool { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 98873abe0c499..9a1c7c790afaa 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4225,7 +4225,7 @@ MTIA: mm_out_mtia MPS: mm_out_mps XPU: mm_out_xpu - SparseCPU, SparseCUDA: _sparse_mm_out + SparseCPU, SparseCUDA, SparseMPS: _sparse_mm_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out - func: mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 81b3ce90b36bf..a522e7ab76cf4 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu + ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp diff --git a/aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp b/aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp new file mode 100644 index 0000000000000..535bb3d1cc2ea --- /dev/null +++ b/aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp @@ -0,0 +1,77 @@ +#include + +#include +#include +#include + +#include +#include +#include + +// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace +// to verify that the data race fix is working correctly + +TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) { + if (!at::cuda::is_available()) { + return; + } + + constexpr int num_accessor_threads = 15; + constexpr int num_clear_threads = 5; + constexpr int 
iterations_per_thread = 50; + + std::atomic stop{false}; + std::atomic error_count{0}; + std::vector threads; + threads.reserve(num_accessor_threads + num_clear_threads); + + // Launch accessor threads + for (int i = 0; i < num_accessor_threads; ++i) { + threads.emplace_back([&stop, &error_count]() { + try { + at::cuda::CUDAGuard device_guard(0); + + while (!stop.load(std::memory_order_relaxed)) { + const auto handle = at::cuda::getCurrentCUDABlasHandle(); + const auto workspace = at::cuda::getCUDABlasLtWorkspace(); + + if (handle == nullptr || workspace == nullptr) { + error_count++; + } + } + } catch (const std::exception& e) { + error_count++; + } + }); + } + + // Launch threads that clear workspaces + for (int i = 0; i < num_clear_threads; ++i) { + threads.emplace_back([&error_count]() { + try { + for (int j = 0; j < iterations_per_thread; ++j) { + at::cuda::clearCublasWorkspaces(); + std::this_thread::yield(); + } + } catch (const std::exception& e) { + error_count++; + } + }); + } + + // Let them run for a bit + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + stop.store(true, std::memory_order_relaxed); + + for (auto& thread : threads) { + thread.join(); + } + + EXPECT_EQ(error_count.load(), 0); +} + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + c10::cuda::CUDACachingAllocator::init(1); + return RUN_ALL_TESTS(); +} diff --git a/aten/tools/valgrind.sup b/aten/tools/valgrind.sup index ad5f66e0b0531..585487c4d2be2 100644 --- a/aten/tools/valgrind.sup +++ b/aten/tools/valgrind.sup @@ -10,6 +10,13 @@ ... } +{ + ignore_empty_generic_uninitialised_conditional_jump + Memcheck:Cond + fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE + ... +} + { Cond_cuda Memcheck:Cond diff --git a/benchmarks/dynamo/check_perf_csv.py b/benchmarks/dynamo/check_perf_csv.py index 320a4544f829b..08070dda4444c 100644 --- a/benchmarks/dynamo/check_perf_csv.py +++ b/benchmarks/dynamo/check_perf_csv.py @@ -9,28 +9,61 @@ def check_perf_csv(filename, threshold, threshold_scale): """ Basic performance checking. 
""" + try: + df = pd.read_csv(filename) + except FileNotFoundError: + print(f"Error: File {filename} not found") + sys.exit(1) - df = pd.read_csv(filename) + effective_threshold = threshold * threshold_scale + print(f"Checking {filename} (speedup threshold >= {effective_threshold:.2f}x)\n") failed = [] for _, row in df.iterrows(): model_name = row["name"] - speedup = row["speedup"] - if speedup < threshold * threshold_scale: - failed.append(model_name) + speedup = float(row["speedup"]) + abs_latency = float(row["abs_latency"]) + compilation_latency = float(row["compilation_latency"]) + compression_ratio = float(row["compression_ratio"]) + eager_peak_mem = float(row["eager_peak_mem"]) + dynamo_peak_mem = float(row["dynamo_peak_mem"]) + + perf_summary = f"{model_name:34} speedup={speedup:.3f}x" + if pd.notna(abs_latency): + perf_summary += f", latency={abs_latency:.1f} ms/iter" + if pd.notna(compilation_latency): + perf_summary += f", compile={compilation_latency:.3f}s" + if pd.notna(compression_ratio): + perf_summary += f", mem_ratio={1 / compression_ratio:.2f}x" + if pd.notna(eager_peak_mem) and pd.notna(dynamo_peak_mem): + perf_summary += ( + f" (eager={eager_peak_mem:.1f} GB, dynamo={dynamo_peak_mem:.1f} GB)" + ) + + if speedup < effective_threshold: + failed.append((model_name, speedup)) - print(f"{model_name:34} {speedup}") + print(perf_summary) if failed: print( textwrap.dedent( f""" - Error {len(failed)} models performance regressed - {" ".join(failed)} + Error {len(failed)} model(s) performance regressed + {" ".join([name for name, _ in failed])} """ ) ) + for name, sp in sorted(failed, key=lambda x: x[1]): + pct_from_target = (sp / effective_threshold - 1.0) * 100.0 + print( + f" - {name}: {sp:.3f}x (< {effective_threshold:.2f}x; {pct_from_target:.1f}% from target)" + ) sys.exit(1) + else: + print( + f"\nAll {len(df)} model(s) passed threshold check (>= {effective_threshold:.2f}x)" + ) if __name__ == "__main__": @@ -44,7 +77,7 @@ def check_perf_csv(filename, threshold, threshold_scale): "-s", type=float, default=1.0, - help="multiple threshold by this value to relax the check", + help="multiply threshold by this value to relax the check", ) args = parser.parse_args() check_perf_csv(args.file, args.threshold, args.threshold_scale) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index a3bd58c4de747..b3484e7196a83 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -2379,7 +2379,9 @@ def record_status(accuracy_status, dynamo_start_stats): print( f"Load model outputs from {self.args.compare_model_outputs_with} to compare" ) - saved_result = torch.load(self.args.compare_model_outputs_with) + saved_result = torch.load( + self.args.compare_model_outputs_with, weights_only=False + ) is_bitwise_same = bitwise_same(saved_result, new_result) if not is_bitwise_same: print( diff --git a/c10/core/SafePyObject.h b/c10/core/SafePyObject.h index 1ec0cdb6751e9..bcace0ac358b4 100644 --- a/c10/core/SafePyObject.h +++ b/c10/core/SafePyObject.h @@ -44,7 +44,7 @@ struct C10_API SafePyObject { (*other.pyinterpreter_)->incref(other.data_); } if (data_ != nullptr) { - (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false); + (*pyinterpreter_)->decref(data_); } data_ = other.data_; pyinterpreter_ = other.pyinterpreter_; @@ -53,7 +53,7 @@ struct C10_API SafePyObject { ~SafePyObject() { if (data_ != nullptr) { - (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false); + (*pyinterpreter_)->decref(data_); } } diff --git a/c10/core/ScalarType.h 
b/c10/core/ScalarType.h index 5bc537dbd83c8..040c6abb7d8e2 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -92,13 +92,6 @@ inline bool isComplexType(ScalarType t) { t == ScalarType::ComplexDouble); } -inline bool isQIntType(ScalarType t) { - // Don't forget to extend this when adding new QInt types - return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || - t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || - t == ScalarType::QUInt2x4; -} - inline bool isBitsType(ScalarType t) { return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index a614fc9234c94..00fc03bbd0fcf 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() { TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); } +void StorageImpl::incref_pyobject() const { + // Because intrusive_ptr incref uses relaxed memory order, we need to + // do an acquire fence to ensure that the kHasPyObject bit was + // observed before the load of the PyObject* below. + // NB: This is a no-op on x86/x86-64 + std::atomic_thread_fence(std::memory_order_acquire); + + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->incref(obj); +} + +void StorageImpl::decref_pyobject() const { + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->decref(obj); +} + +bool StorageImpl::try_incref_pyobject() const { + c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); + if (C10_UNLIKELY(!interp)) { + return false; + } + return (*interp)->try_incref(pyobj_slot_); +} + void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) { // Allowlist verification. // Only if the devicetype is in the allowlist, diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index f34a1baed7a48..c7dbd5c1f005b 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { data_ptr_.clear(); } + void incref_pyobject() const override final; + + void decref_pyobject() const override final; + + bool try_incref_pyobject() const override final; + size_t nbytes() const { // OK to do this instead of maybe_as_int as nbytes is guaranteed positive TORCH_CHECK(!size_bytes_is_heap_allocated_); @@ -370,4 +376,18 @@ C10_API c10::intrusive_ptr make_storage_impl( bool resizable, std::optional device_opt); +namespace detail { + +#ifndef C10_MOBILE +template +struct TargetTraits< + T, + std::enable_if_t< + std::is_base_of_v>>> { + static constexpr bool can_have_pyobject = true; +}; +#endif + +} // namespace detail + } // namespace c10 diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index c59524a0932c2..94a7375cc32fb 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -277,7 +277,6 @@ void TensorImpl::release_resources() { if (storage_) { storage_ = {}; } - pyobj_slot_.maybe_destroy_pyobj(); } #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY @@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { } } +void TensorImpl::incref_pyobject() const { + // Because intrusive_ptr incref uses relaxed memory order, we need to + // do an acquire fence to ensure that the kHasPyObject bit was + // observed before the load of the PyObject* below. 
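// The acquire fence mentioned in the comment above pairs with a release
// read-modify-write on the combined refcount word: a reader that observed the
// kHasPyObject bit through a relaxed load must fence before loading the
// PyObject pointer, otherwise the pointer load could see a stale value. A
// self-contained sketch of that pairing, using made-up names rather than the
// actual PyTorch fields:
#include <atomic>
#include <cstdint>

namespace fence_sketch {
constexpr uint64_t kFlag = uint64_t(1) << 63;
std::atomic<uint64_t> combined{1};
std::atomic<int*> payload{nullptr};

void publish(int* p) {
  payload.store(p, std::memory_order_relaxed);
  // The release RMW orders the payload store before the flag becomes visible.
  combined.fetch_or(kFlag, std::memory_order_release);
}

int* read_if_published() {
  if (combined.load(std::memory_order_relaxed) & kFlag) {
    // Pairs with the release fetch_or above (fence/atomic synchronization).
    std::atomic_thread_fence(std::memory_order_acquire);
    return payload.load(std::memory_order_relaxed);
  }
  return nullptr;
}
}  // namespace fence_sketch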
+ // NB: This is a no-op on x86/x86-64 + std::atomic_thread_fence(std::memory_order_acquire); + + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->incref(obj); +} + +void TensorImpl::decref_pyobject() const { + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->decref(obj); +} + +bool TensorImpl::try_incref_pyobject() const { + c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); + if (C10_UNLIKELY(!interp)) { + return false; + } + return (*interp)->try_incref(pyobj_slot_); +} + namespace impl { namespace { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 420ed73e48d21..71a0195dde773 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2178,6 +2178,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return &pyobj_slot_; } + void incref_pyobject() const override final; + + void decref_pyobject() const override final; + + bool try_incref_pyobject() const override final; + private: // See NOTE [std::optional operator usage in CUDA] // We probably don't want to expose this publicly until @@ -3079,6 +3085,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { friend class C10_TensorImpl_Size_Check_Dummy_Class; }; +namespace detail { + +#ifndef C10_MOBILE +template +struct TargetTraits< + T, + std::enable_if_t>>> { + static constexpr bool can_have_pyobject = true; +}; +#endif + +} // namespace detail + // Note [TensorImpl size constraints] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Changed the size of TensorImpl? If the size went down, good for diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp index 8676f0aaf8e0e..52d263fad36c5 100644 --- a/c10/core/impl/PyInterpreter.cpp +++ b/c10/core/impl/PyInterpreter.cpp @@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { void incref(PyObject* pyobj) const override {} // do nothing - void decref(PyObject* pyobj, bool has_pyobj_slot) const override { - } // do nothing + void decref(PyObject* pyobj) const override {} // do nothing + + bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override { + return false; + } #define PANIC(m) \ TORCH_INTERNAL_ASSERT( \ @@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { "attempted to call " #m \ " on a Tensor with nontrivial PyObject after corresponding interpreter died") + size_t refcnt(PyObject* pyobj) const override { + PANIC(refcnt); + } + c10::intrusive_ptr detach(const TensorImpl* self) const override { PANIC(detach); } diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index def708c24b802..463b1e520b36e 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -18,6 +18,9 @@ namespace c10 { struct IValue; class OperatorHandle; struct TensorImpl; +namespace impl { +struct PyObjectSlot; +} // namespace impl } // namespace c10 namespace torch::jit { @@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable { // Run Py_INCREF on a PyObject. virtual void incref(PyObject* pyobj) const = 0; - // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call - // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg] - virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0; + // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call. + virtual void decref(PyObject* pyobj) const = 0; + // Run PyUnstable_TryIncRef on a PyObject if it's not NULL. 
+ virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0; + // Run Py_REFCNT on a PyObject. + virtual size_t refcnt(PyObject* pyobj) const = 0; // Perform a detach by deferring to the __torch_dispatch__ implementation of // detach, which will also arrange for the PyObject to get copied in this diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp deleted file mode 100644 index 0f1bfb2110747..0000000000000 --- a/c10/core/impl/PyObjectSlot.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include - -namespace c10::impl { - -PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} - -PyObjectSlot::~PyObjectSlot() { - maybe_destroy_pyobj(); -} - -void PyObjectSlot::maybe_destroy_pyobj() { - if (owns_pyobj()) { - TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr); - TORCH_INTERNAL_ASSERT(pyobj_ != nullptr); - (*pyobj_interpreter_.load(std::memory_order_acquire)) - ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true); - // NB: this destructor can only be entered when there are no - // references to this C++ object (obviously), NOR any references - // to the PyObject (if there are references to the PyObject, - // then the PyObject holds an owning reference to the tensor). - // So it is OK to clear pyobj_ here as it is impossible for it to - // be used again (modulo weak reference races) - pyobj_ = nullptr; // for safety - } -} - -PyInterpreter* PyObjectSlot::pyobj_interpreter() { - return pyobj_interpreter_.load(std::memory_order_acquire); -} - -PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - return reinterpret_cast( - reinterpret_cast(pyobj_) & ~0x1ULL); -} - -PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { - auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); - if (interpreter) { - return *interpreter; - } - TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set"); -} - -bool PyObjectSlot::owns_pyobj() { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - return reinterpret_cast(pyobj_) & 1; -} - -void PyObjectSlot::set_owns_pyobj(bool b) { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - pyobj_ = reinterpret_cast( - reinterpret_cast(_unchecked_untagged_pyobj()) | b); -} - -} // namespace c10::impl diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 58b2490eba001..a0633401b3634 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -8,117 +8,58 @@ #include +namespace torch::utils { +class PyObjectPreservation; +} + namespace c10::impl { struct C10_API PyObjectSlot { public: - PyObjectSlot(); - - ~PyObjectSlot(); - - void maybe_destroy_pyobj(); - - // Associate the TensorImpl with the specified PyObject, and, if necessary, - // also tag the interpreter. - // - // NB: This lives in a header so that we can inline away the switch on status - // - // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after - // PyObject if necessary! - void init_pyobj(PyObject* pyobj) { - pyobj_interpreter_.store( - getGlobalPyInterpreter(), std::memory_order_relaxed); - pyobj_ = pyobj; - } + PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} // Query the PyObject interpreter. This may return null if there is no - // interpreter. This is racy! - PyInterpreter* pyobj_interpreter(); - - PyObject* _unchecked_untagged_pyobj() const; - - // Test the interpreter tag. If tagged for the current interpreter, return - // a non-nullopt (but possibly null) PyObject. 
If (possibly) untagged, - // returns a nullopt. If it is definitely invalid, raises an error. - // - // If `ignore_hermetic_tls` is false and this function is called from a - // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then - // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic - // context is ignored, allowing you to check the interpreter tag of a - // nonhermetic PyObject from within a hermetic context. This is necessary - // because there are some cases where the deallocator function of a - // nonhermetic PyObject is called from within a hermetic context, so it must - // be properly treated as a nonhermetic PyObject. - // - // NB: this lives in header so that we can avoid actually creating the - // std::optional + // interpreter. + PyInterpreter* pyobj_interpreter() const { + return pyobj_interpreter_.load(std::memory_order_acquire); + } - // @todo alban: I'm not too sure what's going on here, we can probably delete - // it but it's worthwhile making sure - std::optional check_pyobj(bool ignore_hermetic_tls = false) const { - impl::PyInterpreter* interpreter = - pyobj_interpreter_.load(std::memory_order_acquire); - if (interpreter == nullptr) { - return std::nullopt; - } + PyInterpreter& load_pyobj_interpreter() const { + auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); + TORCH_INTERNAL_ASSERT( + interpreter, "cannot access PyObject for Tensor - no interpreter set"); + return *interpreter; + } - if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { - return std::nullopt; - } else { - return _unchecked_untagged_pyobj(); - } + PyObject* load_pyobj() const { + return pyobj_.load(std::memory_order_acquire); } - PyInterpreter& load_pyobj_interpreter() const; + void store_pyobj(PyObject* obj) { + pyobj_.store(obj, std::memory_order_release); + } - bool owns_pyobj(); + bool has_unique_reference() const { + PyObject* pyobj = load_pyobj(); + return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1; + } - void set_owns_pyobj(bool b); + void clear() { + pyobj_.store(nullptr, std::memory_order_relaxed); + pyobj_interpreter_.store(nullptr, std::memory_order_relaxed); + } private: - // This field contains the interpreter tag for this object. See - // Note [Python interpreter tag] for general context - // - // Note [Memory ordering on Python interpreter tag] - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // What memory_order do we need when accessing this atomic? We don't - // need a single total modification order (as provided by - // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only - // transition from -1 to some positive integer and never changes afterwards. - // Because there is only one modification, it trivially already has a total - // modification order (e.g., we don't need fences or locked instructions on - // x86) - // - // In fact, one could make a reasonable argument that relaxed reads are OK, - // due to the presence of external locking (GIL) to ensure that interactions - // with other data structures are still correctly synchronized, so that - // we fall in the "Single-Location Data Structures" case as described in - // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf - // However, on x86, it doesn't matter if I use acquire or relaxed on the load - // as I get the same assembly in both cases. 
So I just use the more - // conservative acquire (which will impede compiler optimizations but I don't - // care) + // This is now always the global interpreter if the PyObject is set. + // Maybe we can remove this field some day... std::atomic pyobj_interpreter_; - // This field contains a reference to a PyObject representing this Tensor. - // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new - // PyObject for it and set this field. This field does not have to be - // protected by an atomic as it is only allowed to be accessed when you hold - // the GIL, or during destruction of the tensor. - // - // When a PyObject dies, you are obligated to clear this field - // (otherwise, you will try to use-after-free the pyobj); this currently - // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp - // - // NB: Ordinarily, this should not be a strong reference, as if the - // PyObject owns the Tensor, this would create a reference cycle. - // However, sometimes this ownership flips. To track who owns - // who, this has a single pointer tag indicating whether or not the - // C++ object owns the PyObject (the common case, zero, means PyObject - // owns the C++ object); see _unchecked_untagged_pyobj for raw access - // or check_pyobj for checked access. See references to PyObject - // resurrection in torch/csrc/autograd/python_variable.cpp - PyObject* pyobj_; + // The PyObject representing this Tensor or nullptr. Ownership is managed + // by intrusive_ptr. By the time the PyObjectSlot is destroyed, this + // reference is already dead. + std::atomic pyobj_; + + friend class torch::utils::PyObjectPreservation; }; } // namespace c10::impl diff --git a/c10/util/Exception.cpp b/c10/util/Exception.cpp index cccdb28607141..1928c2c175c7b 100644 --- a/c10/util/Exception.cpp +++ b/c10/util/Exception.cpp @@ -1,5 +1,4 @@ #include -#include #include #include @@ -28,7 +27,7 @@ Error::Error( const void* caller) : Error( str("[enforce fail at ", - c10::filesystem::path(file).filename(), + detail::StripBasename(file), ":", line, "] ", diff --git a/c10/util/Logging.cpp b/c10/util/Logging.cpp index b95eaec9d3ebb..4bf96b1b6808a 100644 --- a/c10/util/Logging.cpp +++ b/c10/util/Logging.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -479,7 +478,8 @@ MessageLogger::MessageLogger( << std::setfill('0') << " " << std::setw(2) << timeinfo->tm_hour << ":" << std::setw(2) << timeinfo->tm_min << ":" << std::setw(2) << timeinfo->tm_sec << "." << std::setw(9) << ns << " " - << c10::filesystem::path(file).filename() << ":" << line << "] "; + << c10::detail::StripBasename(std::string(file)) << ":" << line + << "] "; } // Output the contents of the stream to the proper channel on destruction. diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 3d5478be90e60..0c8f55f5061ab 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -12,6 +12,10 @@ template class class_; } +namespace torch::utils { +class PyObjectPreservation; +} + namespace c10 { class intrusive_ptr_target; namespace raw { @@ -33,6 +37,8 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount = constexpr uint64_t kReferenceCountOne = 1; constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32); constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne); +// Indicates whether the object has a PyObject wrapper. 
+constexpr uint64_t kHasPyObject = (uint64_t(1) << 63); template struct intrusive_target_default_null_type final { @@ -55,7 +61,11 @@ inline uint32_t refcount(uint64_t combined_refcount) { } inline uint32_t weakcount(uint64_t combined_refcount) { - return static_cast(combined_refcount >> 32); + return static_cast((combined_refcount & ~kHasPyObject) >> 32); +} + +inline bool has_pyobject(uint64_t combined_refcount) { + return (combined_refcount & kHasPyObject) != 0; } // The only requirement for refcount increment is that it happens-before @@ -66,12 +76,6 @@ inline uint64_t atomic_combined_refcount_increment( return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc; } -inline uint32_t atomic_refcount_increment( - std::atomic& combined_refcount) { - return detail::refcount(atomic_combined_refcount_increment( - combined_refcount, kReferenceCountOne)); -} - inline uint32_t atomic_weakcount_increment( std::atomic& combined_refcount) { return detail::weakcount(atomic_combined_refcount_increment( @@ -99,6 +103,11 @@ inline uint32_t atomic_weakcount_decrement( combined_refcount, kWeakReferenceCountOne)); } +template +struct TargetTraits { + static constexpr bool can_have_pyobject = false; +}; + } // namespace detail /** @@ -155,6 +164,23 @@ class C10_API intrusive_ptr_target { // we can atomically operate on both at the same time for performance // and defined behaviors. // + // Note [PyObject preservation for Tensor and Storages] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // intrusive_ptr has special support for preserving PyObject wrappers + // for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of + // the combined_refcount_ is used to indicate whether the object has a + // PyObject wrapper. + // + // - The PyObject, if it exists, holds a strong reference to the + // intrusive_ptr_target. + // + // - When the refcount goes from 1 to 2, we incref the PyObject. + // + // - When the refcount goes from 2 to 1, we decref the PyObject. + // + // In other words, the intrusive_ptr keeps the PyObject alive as long as there + // are other C++ references to the intrusive_ptr_target. + mutable std::atomic combined_refcount_; static_assert(sizeof(std::atomic) == 8); static_assert(alignof(std::atomic) == 8); @@ -172,6 +198,8 @@ class C10_API intrusive_ptr_target { template friend struct ExclusivelyOwnedTensorTraits; + friend class torch::utils::PyObjectPreservation; + protected: // protected destructor. We never want to destruct intrusive_ptr_target* // directly. @@ -255,6 +283,16 @@ class C10_API intrusive_ptr_target { */ virtual void release_resources() {} + /** + * These two methods are called when the refcount transitions between one + * and two and the object has a PyObject wrapper. + */ + virtual void incref_pyobject() const {} + virtual void decref_pyobject() const {} + virtual bool try_incref_pyobject() const { + return false; + } + uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const { return detail::refcount(combined_refcount_.load(order)); } @@ -265,6 +303,19 @@ class C10_API intrusive_ptr_target { } }; +namespace detail { + +#ifndef C10_MOBILE +template <> +struct TargetTraits { + // A generic intrusive_ptr may actually be a TensorImpl + // or StorageImpl, so we have to allow for PyObject support. 
+ static constexpr bool can_have_pyobject = true; +}; +#endif + +} // namespace detail + template class weak_intrusive_ptr; @@ -314,18 +365,34 @@ class intrusive_ptr final { void retain_() { if (target_ != NullType::singleton()) { - uint32_t new_refcount = - detail::atomic_refcount_increment(target_->combined_refcount_); + uint64_t combined = detail::atomic_combined_refcount_increment( + target_->combined_refcount_, detail::kReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( new_refcount != 1, "intrusive_ptr: Cannot increase refcount after it reached zero."); + + if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 1 to 2, we need to incref the + // PyObject. In other words, we need to ensure that the PyObject stays + // alive now that we have a C++ reference to this object in addition to + // the PyObject itself. + if (C10_UNLIKELY( + detail::has_pyobject(combined) && + detail::refcount(combined) == 2)) { + target_->incref_pyobject(); + } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !detail::has_pyobject(combined), + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + } } } void reset_() noexcept { if (target_ != NullType::singleton()) { - if (target_->combined_refcount_.load(std::memory_order_acquire) == - detail::kUniqueRef) { + if (is_uniquely_owned()) { // Both counts are 1, so there are no weak references and // we are releasing the last strong reference. No other // threads can observe the effects of this target_ deletion @@ -337,9 +404,10 @@ class intrusive_ptr final { auto combined_refcount = detail::atomic_combined_refcount_decrement( target_->combined_refcount_, detail::kReferenceCountOne); - if (detail::refcount(combined_refcount) == 0) { - bool should_delete = - (combined_refcount == detail::kWeakReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined_refcount); + bool has_pyobject = detail::has_pyobject(combined_refcount); + if (new_refcount == 0) { + bool should_delete = detail::weakcount(combined_refcount) == 1; // See comment above about weakcount. As long as refcount>0, // weakcount is one larger than the actual number of weak references. // So we need to decrement it here. @@ -356,6 +424,18 @@ class intrusive_ptr final { if (should_delete) { delete target_; } + } else if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 2 to 1, we need to decref the + // PyObject. In other words, we don't want to keep the PyObject alive if + // there are no C++ references to this object other than the PyObject + // itself. + if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) { + target_->decref_pyobject(); + } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !has_pyobject, + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); } } } @@ -522,6 +602,16 @@ class intrusive_ptr final { return use_count() == 1; } + /** + * Stronger than unique() in that it must not have any weakrefs as well. + */ + bool is_uniquely_owned() const noexcept { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton()); + uint64_t combined = + target_->combined_refcount_.load(std::memory_order_acquire); + return (combined & ~detail::kHasPyObject) == detail::kUniqueRef; + } + /** * Returns an owning (!) pointer to the underlying object and makes the * intrusive_ptr instance invalid. That means the refcount is not decreased. 
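[Editorial aside, not part of the patch] The intrusive_ptr.h hunks above pack the strong count (low 32 bits), the weak count, and the new kHasPyObject flag (bit 63) into the single 64-bit combined_refcount_, and call incref_pyobject() / decref_pyobject() exactly on the 1 -> 2 and 2 -> 1 strong-count transitions. The standalone C++ sketch below only mirrors that bit layout and transition rule; the names Counted, note_incref_pyobject and note_decref_pyobject are illustrative and do not exist in the patch.

#include <atomic>
#include <cstdint>
#include <cstdio>

// Illustrative layout mirroring the patch: low 32 bits = strong refcount,
// next 31 bits = weak refcount, top bit = "has PyObject wrapper" flag.
constexpr uint64_t kRefOne = 1;
constexpr uint64_t kWeakOne = uint64_t(1) << 32;
constexpr uint64_t kHasPyObject = uint64_t(1) << 63;

inline uint32_t refcount(uint64_t c) { return static_cast<uint32_t>(c); }
inline uint32_t weakcount(uint64_t c) {
  return static_cast<uint32_t>((c & ~kHasPyObject) >> 32);
}
inline bool has_pyobject(uint64_t c) { return (c & kHasPyObject) != 0; }

struct Counted {
  // One strong ref plus the implicit weak ref, wrapper flag set.
  std::atomic<uint64_t> combined{kRefOne | kWeakOne | kHasPyObject};

  void note_incref_pyobject() { std::puts("incref PyObject (1 -> 2)"); }
  void note_decref_pyobject() { std::puts("decref PyObject (2 -> 1)"); }

  void retain() {
    uint64_t c = combined.fetch_add(kRefOne, std::memory_order_relaxed) + kRefOne;
    // A second C++ reference appeared: keep the Python wrapper alive.
    if (has_pyobject(c) && refcount(c) == 2) note_incref_pyobject();
  }
  void release() {
    uint64_t c = combined.fetch_sub(kRefOne, std::memory_order_acq_rel) - kRefOne;
    // The PyObject is the only remaining owner again: drop the extra wrapper ref.
    if (has_pyobject(c) && refcount(c) == 1) note_decref_pyobject();
  }
};

int main() {
  Counted t;
  t.retain();   // strong count 1 -> 2
  t.release();  // strong count 2 -> 1
  std::printf("refcount=%u weakcount=%u has_pyobject=%d\n",
              refcount(t.combined.load()), weakcount(t.combined.load()),
              int(has_pyobject(t.combined.load())));
}

Running the sketch prints the incref on the first retain() and the decref on the matching release(), which is the same point at which the patched retain_() / reset_() notify the Python wrapper.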
@@ -932,6 +1022,7 @@ class weak_intrusive_ptr final {
     if (target_ == NullType::singleton()) {
       return intrusive_ptr();
     } else {
+      bool increfed = false;
       auto combined_refcount =
           target_->combined_refcount_.load(std::memory_order_relaxed);
       do {
@@ -940,12 +1031,31 @@
           // Return nullptr.
           return intrusive_ptr();
         }
+        if constexpr (detail::TargetTraits::can_have_pyobject) {
+          if (detail::has_pyobject(combined_refcount) &&
+              detail::refcount(combined_refcount) == 1 && !increfed) {
+            // Object has a Python wrapper with no other C++ references.
+            // We need to incref the Python object before we acquire a
+            // strong reference to the C++ object to avoid a situation
+            // where the Python object is deallocated concurrently.
+            if (!target_->try_incref_pyobject()) {
+              return intrusive_ptr();
+            }
+            increfed = true;
+          }
+        }
       } while (!target_->combined_refcount_.compare_exchange_weak(
           combined_refcount,
           combined_refcount + detail::kReferenceCountOne,
           std::memory_order_acquire,
           std::memory_order_relaxed));
+      if constexpr (detail::TargetTraits::can_have_pyobject) {
+        if (increfed && detail::refcount(combined_refcount) != 1) {
+          target_->decref_pyobject();
+        }
+      }
+
       return intrusive_ptr(
           target_, raw::DontIncreaseRefcount{});
     }
@@ -1060,7 +1170,18 @@ namespace intrusive_ptr {
 // NullType::singleton to this function
 inline void incref(intrusive_ptr_target* self) {
   if (self) {
-    detail::atomic_refcount_increment(self->combined_refcount_);
+    uint64_t combined = detail::atomic_combined_refcount_increment(
+        self->combined_refcount_, detail::kReferenceCountOne);
+
+#ifndef C10_MOBILE
+    if (C10_UNLIKELY(
+            detail::has_pyobject(combined) &&
+            detail::refcount(combined) == 2)) {
+      self->incref_pyobject();
+    }
+#else
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
+#endif
   }
 }
diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp
index ba748449b29e3..3bd9eff0fee63 100644
--- a/c10/xpu/XPUCachingAllocator.cpp
+++ b/c10/xpu/XPUCachingAllocator.cpp
@@ -15,6 +15,8 @@ using namespace c10::CachingDeviceAllocator;
 // newly allocated memory with 512-byte alignment.
 constexpr size_t kDeviceAlignment = 512;
 
+class XPUAllocator;
+
 namespace {
 
 using stream_set = ska::flat_hash_set;
 
 typedef bool (*Comparison)(const Block*, const Block*);
 bool BlockComparatorSize(const Block* a, const Block* b);
 bool BlockComparatorAddress(const Block* a, const Block* b);
 
+struct PrivatePool;
+
 struct BlockPool {
-  BlockPool(bool small)
+  BlockPool(bool small, PrivatePool* private_pool = nullptr)
       : blocks(BlockComparatorSize),
         unmapped(BlockComparatorAddress),
-        is_small(small) {}
+        is_small(small),
+        owner_PrivatePool(private_pool) {}
+
   std::set blocks;
   std::set unmapped;
   const bool is_small;
+  PrivatePool* owner_PrivatePool;
 };
 
 struct ExpandableSegment;
@@ -349,6 +356,43 @@ struct AllocParams {
   StatTypes stat_types = {};
 };
 
+// Internal implementation that manages actual memory blocks.
+// The high-level MemPool interface wraps PrivatePool via MempoolId.
+struct PrivatePool { + PrivatePool(MempoolId_t id, XPUAllocator* allocator = nullptr) + : id(std::move(id)), + allocator_(allocator), + large_blocks(/*small=*/false, this), + small_blocks(/*small=*/true, this) {} + PrivatePool(const PrivatePool&) = delete; + PrivatePool(PrivatePool&&) = delete; + PrivatePool& operator=(const PrivatePool&) = delete; + PrivatePool& operator=(PrivatePool&&) = delete; + ~PrivatePool() = default; + + // default Mempool when no Mempool is specified + MempoolId_t id{0, 0}; + // Number of live graphs using this pool + int use_count{1}; + // Number of unfreed allocations made for this pool. When use_count and + // allocation_count drop to zero, we can delete this PrivatePool from + // graph_pools. + int allocation_count{0}; + XPUAllocator* allocator_; + BlockPool large_blocks; + BlockPool small_blocks; + + public: + XPUAllocator* allocator() { + return allocator_; + } +}; +struct MempoolIdHash { + std::size_t operator()(const MempoolId_t& mempool_id) const noexcept { + return mempool_id.first != 0 ? mempool_id.first : mempool_id.second; + } +}; + } // anonymous namespace class DeviceCachingAllocator { @@ -365,6 +409,13 @@ class DeviceCachingAllocator { bool set_fraction = false; std::vector expandable_segments; std::vector devices_with_peer_access; // reserved + std::vector>> + captures_underway; + ska::flat_hash_map, MempoolIdHash> + graph_pools; + // Pools no longer referenced by any graph. + ska::flat_hash_map + graph_pools_freeable; size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { if (!src || src->allocated || src->event_count > 0 || @@ -463,7 +514,22 @@ class DeviceCachingAllocator { } } - BlockPool& get_pool(size_t size) { + BlockPool& get_pool(size_t size, sycl::queue* queue) { + if (C10_UNLIKELY(!captures_underway.empty())) { + for (auto& entry : captures_underway) { + // lookup for mempool id matching current capture graph + if (entry.second(queue)) { + auto it1 = graph_pools.find(entry.first); + // lookup mempool + TORCH_INTERNAL_ASSERT(it1 != graph_pools.end()); + if (size <= kSmallSize) { + return it1->second->small_blocks; + } else { + return it1->second->large_blocks; + } + } + } + } if (size < kSmallSize) { return small_blocks; } else { @@ -669,6 +735,10 @@ class DeviceCachingAllocator { if (!ptr) { return false; } + + if (p.pool->owner_PrivatePool) { + p.pool->owner_PrivatePool->allocation_count++; + } p.block = new Block(device, p.queue(), size, p.pool, ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { stats.reserved_bytes[stat_type].increase(size); @@ -677,11 +747,14 @@ class DeviceCachingAllocator { return true; } - void synchronize_and_free_events() { + void synchronize_and_free_events(PrivatePool* pool = nullptr) { for (auto& xe : xpu_events) { for (auto& e : xe.second) { auto event = e.first; auto* block = e.second; + if (pool && block->pool->owner_PrivatePool != pool) { + continue; + } event.wait(); block->event_count--; if (block->event_count == 0) { @@ -785,6 +858,13 @@ class DeviceCachingAllocator { for_each_selected_stat_type(stat_types, [&](size_t stat_type) { stats.reserved_bytes[stat_type].decrease(unmapped.size); }); + + if (block->pool->owner_PrivatePool) { + // The Freed block belonged to a XPU graph's PrivatePool. 
+ TORCH_INTERNAL_ASSERT( + block->pool->owner_PrivatePool->allocation_count > 0); + block->pool->owner_PrivatePool->allocation_count--; + } } void release_blocks(BlockPool& pool) { @@ -812,13 +892,41 @@ class DeviceCachingAllocator { } } - bool release_cached_blocks() { - synchronize_and_free_events(); - // See Note [Safe to Free Blocks on BlockPool] - c10::xpu::syncStreamsOnDevice(device_index); + bool release_cached_blocks(MempoolId_t mempool_id) { + if (mempool_id.first == 0 && mempool_id.second == 0 && + captures_underway.empty()) { + synchronize_and_free_events(); + // See Note [Safe to Free Blocks on BlockPool] + c10::xpu::syncStreamsOnDevice(device_index); + + release_blocks(large_blocks); + release_blocks(small_blocks); + } - release_blocks(large_blocks); - release_blocks(small_blocks); + for (auto it = graph_pools_freeable.begin(); + it != graph_pools_freeable.end();) { + if (mempool_id.first != 0 || mempool_id.second != 0) { + if (it->first == mempool_id) { + // If there is an active mempool, we sync only the events + // associated with the pool + synchronize_and_free_events(it->second); + } else { + // otherwise we move on + ++it; + continue; + } + } + TORCH_INTERNAL_ASSERT(it->second->use_count == 0); + release_blocks(it->second->small_blocks); + release_blocks(it->second->large_blocks); + if (it->second->allocation_count == 0) { + auto erase_count = graph_pools.erase(it->first); + TORCH_INTERNAL_ASSERT(erase_count == 1); + it = graph_pools_freeable.erase(it); + } else { + ++it; + } + } return true; } @@ -903,6 +1011,30 @@ class DeviceCachingAllocator { } } + void create_or_incref_pool( + MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr) { + auto it = graph_pools.find(mempool_id); + if (it == graph_pools.end()) { + // mempool_id does not reference an existing pool. + // Make a new pool for XPU graph capture or memory pool usage. + graph_pools.emplace( + mempool_id, std::make_unique(mempool_id, allocator)); + } else { + // mempool_id references an existing pool, which the current XPU graph + // capture will share. + TORCH_INTERNAL_ASSERT(it->second->use_count > 0); + TORCH_INTERNAL_ASSERT(allocator == nullptr); + it->second->use_count++; + } + } + + PrivatePool* get_private_pool(MempoolId_t mempool_id) { + auto it = graph_pools.find(mempool_id); + TORCH_INTERNAL_ASSERT(it != graph_pools.end()); + return it->second.get(); + } + public: DeviceCachingAllocator(DeviceIndex device_index) : large_blocks(/* small */ false), @@ -911,9 +1043,11 @@ class DeviceCachingAllocator { Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) { std::scoped_lock lock(mutex); - process_events(); + if (C10_LIKELY(captures_underway.empty())) { + process_events(); + } size_t size = round_size(orig_size); - auto& pool = get_pool(size); + auto& pool = get_pool(size, &queue); const size_t alloc_size = get_allocation_size(size); AllocParams params(device, size, &queue, &pool, alloc_size); params.stat_types = get_stat_types_for_pool(pool); @@ -923,7 +1057,7 @@ class DeviceCachingAllocator { // Can't reuse an existing block, try to get a new one. 
if (!block_found) { block_found = alloc_block(params, false) || - (release_cached_blocks() && alloc_block(params, true)); + (release_cached_blocks({0, 0}) && alloc_block(params, true)); } if (!block_found) { const auto& raw_device = c10::xpu::get_raw_device(device); @@ -1016,9 +1150,9 @@ class DeviceCachingAllocator { block->stream_uses.insert(stream); } - void emptyCache() { + void emptyCache(MempoolId_t mempool_id) { std::scoped_lock lock(mutex); - release_cached_blocks(); + release_cached_blocks(mempool_id); } DeviceStats getStats() { @@ -1172,9 +1306,9 @@ class XPUAllocator : public DeviceAllocator { } } - void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { + void emptyCache(MempoolId_t mempool_id) override { for (auto& da : device_allocators) { - da->emptyCache(); + da->emptyCache(mempool_id); } } @@ -1290,8 +1424,8 @@ void init(DeviceIndex device_count) { return allocator.init(device_count); } -void emptyCache() { - return allocator.emptyCache(); +void emptyCache(MempoolId_t mempool_id) { + return allocator.emptyCache(mempool_id); } void resetPeakStats(DeviceIndex device) { diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index b1f41a103f8f8..bbb20a5b2ecdf 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -10,7 +10,7 @@ C10_XPU_API Allocator* get(); C10_XPU_API void init(DeviceIndex device_count); -C10_XPU_API void emptyCache(); +C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0}); C10_XPU_API void resetPeakStats(DeviceIndex device); diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index d6c2bfd39c43a..0193a6bc180f1 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -773,8 +773,20 @@ void PyTorchStreamWriter::writeRecord( bool compress) { AT_ASSERT(!finalized_); AT_ASSERT(!archive_name_plus_slash_.empty()); - TORCH_INTERNAL_ASSERT( - files_written_.count(name) == 0, "Tried to serialize file twice: ", name); + if (files_written_.count(name) > 0) { + // Allow multiple writes for triton binaries + bool is_triton_extension = + c10::ends_with(name, ".so") || + c10::ends_with(name, ".cubin") || + c10::ends_with(name, ".hsaco"); + + if (is_triton_extension) { + LOG(WARNING) << "File '" << name << "' is being serialized multiple times"; + return; + } + + TORCH_INTERNAL_ASSERT(false, "Tried to serialize file twice: ", name); + } if (name == kSerializationIdRecordName && serialization_id_.empty()) { // In case of copying records from another file, skip writing a different // serialization_id than the one computed in this writer. 
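[Editorial aside, not part of the patch] The PyTorchStreamWriter::writeRecord() hunk above relaxes the duplicate-write assert only for Triton-produced binaries, keyed on the record's file suffix. The self-contained sketch below illustrates that gate; the hand-rolled ends_with and the names duplicate_write_allowed / files_written are stand-ins for c10::ends_with and the writer's internal files_written_ set.

#include <cstdio>
#include <set>
#include <string>
#include <string_view>

// Suffix check written out so the sketch does not depend on c10::ends_with.
static bool ends_with(std::string_view s, std::string_view suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// A repeated record is tolerated only for Triton build artifacts,
// which may legitimately be emitted more than once.
static bool duplicate_write_allowed(std::string_view name) {
  return ends_with(name, ".so") || ends_with(name, ".cubin") ||
         ends_with(name, ".hsaco");
}

int main() {
  std::set<std::string> files_written;
  for (const char* name : {"model/data.pkl", "triton/kernel.cubin",
                           "triton/kernel.cubin", "model/data.pkl"}) {
    if (!files_written.insert(name).second) {
      if (duplicate_write_allowed(name)) {
        std::printf("WARNING: '%s' serialized multiple times, skipping\n", name);
        continue;
      }
      std::printf("ERROR: tried to serialize '%s' twice\n", name);
      continue;
    }
    std::printf("wrote '%s'\n", name);
  }
}

The design point is that a repeated .so/.cubin/.hsaco record degrades to a warning and an early return, while any other duplicate record name still fails hard.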
diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index c52fe1d2443b6..55d03b7c46320 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -118,11 +118,6 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") endif() endif() - if("${_arch}" STREQUAL "121a") - if(_existing_arch_flags MATCHES ".*compute_120.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a") - endif() - endif() endforeach() list(JOIN _file_compile_flags " " _file_compile_flags) @@ -131,7 +126,7 @@ if(INTERN_BUILD_ATEN_OPS) _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu" - "89;90a;100a;103a;120a;121a") + "89;90a;100a;103a;120a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu" "90a") diff --git a/test/cpp/aoti_abi_check/test_scalartype.cpp b/test/cpp/aoti_abi_check/test_scalartype.cpp index c299d58664c8e..6df242b5a4cec 100644 --- a/test/cpp/aoti_abi_check/test_scalartype.cpp +++ b/test/cpp/aoti_abi_check/test_scalartype.cpp @@ -101,3 +101,14 @@ TEST(TestScalarType, toUnderlying) { AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK); #undef DEFINE_CHECK } + +TEST(TestScalarType, isQIntType) { + using torch::headeronly::isQIntType; + using torch::headeronly::ScalarType; +#define DEFINE_CHECK(_, name) EXPECT_TRUE(isQIntType(ScalarType::name)); + AT_FORALL_QINT_TYPES(DEFINE_CHECK); +#undef DEFINE_CHECK +#define DEFINE_CHECK(_, name) EXPECT_FALSE(isQIntType(ScalarType::name)); + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK); +#undef DEFINE_CHECK +} diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index 58f87717844de..66295d0380629 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -15,7 +15,7 @@ namespace jit { TEST(CustomOperatorTest, InferredSchema) { torch::RegisterOperators reg( "foo::bar", [](double a, at::Tensor b) { return a + b; }); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -43,8 +43,7 @@ TEST(CustomOperatorTest, ExplicitSchema) { "foo::bar_with_schema(float a, Tensor b) -> Tensor", [](double a, at::Tensor b) { return a + b; }); - auto& ops = - getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -77,7 +76,7 @@ TEST(CustomOperatorTest, ListParameters) { torch::List> complexdoubles, torch::List tensors) { return floats; }); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -123,7 +122,7 @@ TEST(CustomOperatorTest, ListParameters2) { "foo::lists2(Tensor[] tensors) -> Tensor[]", [](torch::List tensors) { return tensors; }); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -213,7 +212,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) { }, aliasAnalysisFromSchema())}); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); ASSERT_EQ(ops.size(), 0); } @@ 
-232,7 +231,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) { }, aliasAnalysisFromSchema())}); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp index 72c78984b5215..c266660c232f7 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/kernel.cpp @@ -203,3 +203,42 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) m.impl("my_reshape", TORCH_BOX(&my_reshape)); m.impl("my_view", TORCH_BOX(&my_view)); } + +uint64_t get_any_data_ptr(Tensor t, bool mutable_) { + if (mutable_) { + return reinterpret_cast(t.mutable_data_ptr()); + } else { + return reinterpret_cast(t.const_data_ptr()); + } +} + +uint64_t get_template_any_data_ptr(Tensor t, c10::ScalarType dtype, bool mutable_) { +#define DEFINE_CASE(T, name) \ + case torch::headeronly::ScalarType::name: { \ + if (mutable_) { \ + return reinterpret_cast(t.mutable_data_ptr()); \ + } else { \ + return reinterpret_cast(t.const_data_ptr()); \ + } \ + } + switch (dtype) { + // per aten/src/ATen/templates/TensorMethods.cpp: + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) + DEFINE_CASE(uint16_t, UInt16) + DEFINE_CASE(uint32_t, UInt32) + DEFINE_CASE(uint64_t, UInt64) + default: + return 0; + } +#undef DEFINE_CASE +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("get_any_data_ptr(Tensor t, bool mutable_) -> int"); + m.def("get_template_any_data_ptr(Tensor t, ScalarType dtype, bool mutable_) -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { + m.impl("get_any_data_ptr", TORCH_BOX(&get_any_data_ptr)); + m.impl("get_template_any_data_ptr", TORCH_BOX(&get_template_any_data_ptr)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index 42c437ebf755e..db1a4fd43033c 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -197,3 +197,29 @@ def my_view(t, size) -> Tensor: Returns: Tensor - tensor with new view """ return torch.ops.libtorch_agnostic_2_10.my_view.default(t, size) + + +def get_any_data_ptr(t, mutable) -> int: + """ + Return data pointer value of the tensor. + Args: + t: Input tensor + mutable: whether data pointer qualifier is mutable or const + Returns: int - pointer value + """ + return torch.ops.libtorch_agnostic_2_10.get_any_data_ptr.default(t, mutable) + + +def get_template_any_data_ptr(t, dtype, mutable) -> int: + """ + Return data pointer value of the tensor iff it has dtype. + Args: + t: Input tensor + dtype: Input dtype + mutable: whether data pointer qualifier is mutable or const + Returns: int - pointer value + Raises RuntimeError when t.dtype() != dtype. 
+ """ + return torch.ops.libtorch_agnostic_2_10.get_template_any_data_ptr.default( + t, dtype, mutable + ) diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp index 9541c77a87380..0304dfd8f0f4c 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp @@ -309,7 +309,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { m.def("my_amax(Tensor a) -> Tensor"); m.def("my_amax_vec(Tensor a) -> Tensor"); m.def("my_is_cpu(Tensor t) -> bool"); - m.def("test_default_constructor(bool undefined) -> bool"); + m.def("test_default_constructor(bool undefined) -> bool"); } bool test_default_constructor(bool defined) { @@ -331,7 +331,6 @@ bool test_default_constructor(bool defined) { return out.defined(); } - STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { m.impl("my_zero_", TORCH_BOX(&my_zero_)); m.impl("my_amax", TORCH_BOX(&my_amax)); diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index 2ba1200f230d7..48ede590cecbf 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -15,11 +15,38 @@ install_cpp_extension, IS_WINDOWS, run_tests, + skipIfTorchDynamo, TestCase, xfailIfTorchDynamo, ) +def get_supported_dtypes(): + """Return a list of dtypes that are supported by torch stable ABI.""" + return [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + torch.bfloat16, + torch.float16, + torch.float32, + torch.float64, + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e5m2fnuz, + torch.float8_e4m3fnuz, + torch.complex32, + torch.complex64, + torch.complex128, + torch.bool, + ] + + def skipIfTorchVersionLessThan(major, minor): """Skip test if PyTorch version is less than specified version.""" @@ -700,6 +727,45 @@ def test_mv_tensor_accessor(self, device): expected = torch.mv(m, v) self.assertEqual(result, expected) + @skipIfTorchVersionLessThan(2, 10) + @skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor") + def test_get_any_data_ptr(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + t = torch.empty(2, 5, device=device, dtype=torch.float32) + expected_p = t.data_ptr() + + for mutable in [True, False]: + p = libtorch_agnostic.ops.get_any_data_ptr(t, mutable) + self.assertEqual(p, expected_p) + + @skipIfTorchVersionLessThan(2, 10) + @skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor") + def test_get_template_any_data_ptr(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + supported_dtypes = get_supported_dtypes() + + for dtype in supported_dtypes: + t = torch.empty(2, 5, device=device, dtype=dtype) + expected_p = t.data_ptr() + + for rdtype in supported_dtypes: + if dtype == rdtype: + for mutable in [True, False]: + p = libtorch_agnostic.ops.get_template_any_data_ptr( + t, rdtype, mutable + ) + self.assertEqual(p, expected_p) + else: + for mutable in [True, False]: + with self.assertRaisesRegex( + RuntimeError, "expected scalar type.* but found" + ): + libtorch_agnostic.ops.get_template_any_data_ptr( + t, rdtype, mutable + ) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if 
__name__ == "__main__": diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp index a526bebd26144..9791006d1498f 100644 --- a/test/custom_operator/test_custom_ops.cpp +++ b/test/custom_operator/test_custom_ops.cpp @@ -22,7 +22,7 @@ void check_all_parameters( template Result get_operator_from_registry_and_execute(const char* op_name, Args&&... args) { - auto& ops = torch::jit::getAllOperatorsFor( + auto ops = torch::jit::getAllOperatorsFor( torch::jit::Symbol::fromQualString(op_name)); TORCH_INTERNAL_ASSERT(ops.size() == 1); diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 0a108590bc5ed..9375c86d35584 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -65,7 +65,6 @@ device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" -curr_backend = dist.get_default_backend_for_device(device_type) class SimpleModel(nn.Module): @@ -423,10 +422,10 @@ class TestFullyShard2DStateDict(DTensorTestBase): @property def backend(self): # need to specify gloo backend for testing cpu offload - return f"cpu:gloo,{device_type}:{curr_backend}" + return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl" - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) def test_fully_shard_tp_2d_set_full_state_dict(self): dummy_model = SimpleModel().to(device_type) mesh_2d = init_device_mesh( @@ -515,8 +514,8 @@ def _check_module(self, m1, m2, check_grad=False): ).to_local() self.assertEqual(param_m2, param_m1) - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) def test_2d_ddp_integration_functionality(self) -> None: model, twod_model, dp_pg = self.init_model(self.device_type) optim = torch.optim.Adam(model.parameters(), lr=3e-5) @@ -567,8 +566,8 @@ def _compare_params(self, m1, m2): p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local() self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}") - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) def test_2d_fsdp_state_enable_extension(self): mesh_2d = init_device_mesh( self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp") @@ -643,18 +642,18 @@ def _test_2d_e2e_training( # Ensure all params are still the same after optimizer update. self._compare_params(model, model_2d) - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) def test_2d_e2e_training_default(self): self._test_2d_e2e_training() - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) def test_2d_e2e_training_use_orig_params(self): self._test_2d_e2e_training(use_orig_params=True) - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) def test_2d_e2e_training_not_use_orig_params(self): # TODO: need to revisit input_reshard API about why it failed multi-gpu tests. # self._test_2d_e2e_training(recompute_activation=True) @@ -667,10 +666,10 @@ class TestNew2dParallelStateDict(DTensorTestBase): @property def backend(self): # need to specify gloo backend for testing cpu offload - return f"cpu:gloo,{device_type}:{curr_backend}" + return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl" - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) def test_fsdp_2d_extension(self): """ Test whether _fsdp_extension from FSDPstate has been set correctly. 
@@ -701,8 +700,8 @@ def test_fsdp_2d_extension(self): model_1d_fsdp_state = _get_module_fsdp_state(model_1d) self.assertEqual(model_1d_fsdp_state._fsdp_extension, None) - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) @parametrize("is_even_sharded_model", [True, False]) def test_2d_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven @@ -757,8 +756,8 @@ def test_2d_state_dict(self, is_even_sharded_model): torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True ) - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) @parametrize("is_even_sharded_model", [True, False]) def test_2d_load_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven @@ -812,8 +811,8 @@ def test_2d_load_state_dict(self, is_even_sharded_model): self.assertEqual(v1.device_mesh, v2.device_mesh) self.assertEqual(v1.placements, v2.placements) - @skip_if_lt_x_gpu(4) @with_comms + @skip_if_lt_x_gpu(4) @parametrize("is_even_sharded_model", [True, False]) def test_2d_optim_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven @@ -900,9 +899,9 @@ def test_2d_optim_state_dict(self, is_even_sharded_model): else: self.assertEqual(new_state, state) - @skip_if_lt_x_gpu(4) @with_comms @with_temp_dir + @skip_if_lt_x_gpu(4) def test_fsdp1_tp_2d_set_full_state_dict(self): """ This is a workaround for loading full state dict into a FSDP1+TP 2D model. diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py index 9ddbe867fa879..a66518fc0ef0f 100644 --- a/test/distributed/_composable/test_composability/test_pp_composability.py +++ b/test/distributed/_composable/test_composability/test_pp_composability.py @@ -29,8 +29,8 @@ parallelize_module, RowwiseParallel, ) +from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( - at_least_x_gpu, MultiProcessTestCase, requires_accelerator_dist_backend, skip_if_lt_x_gpu, @@ -40,6 +40,7 @@ parametrize, run_tests, skip_but_pass_in_sandcastle_if, + TEST_XPU, ) from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir @@ -106,9 +107,11 @@ def world_size(self): def device(self): return self.rank - @requires_accelerator_dist_backend() + @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs" + ) def test_pp_and_dcp(self): """ Test that pipeline parallelism and distributed checkpointing can be used together and @@ -198,9 +201,11 @@ def _dcp_test(self): _dcp_test(self) - @requires_accelerator_dist_backend() + @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs" + ) @parametrize( "ScheduleClass", [ @@ -350,9 +355,11 @@ def apply_tp( torch.distributed.destroy_process_group() - @requires_accelerator_dist_backend() + @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIGPU and not 
TEST_XPU, "Test requires 8+ GPUs" + ) @parametrize( "ScheduleClass", [ @@ -543,9 +550,11 @@ def apply_same_precision(partial_model): torch.distributed.destroy_process_group() - @requires_accelerator_dist_backend() + @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs" + ) @parametrize( "ScheduleClass", [ diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py index 2099a2a2d44d9..89a893037c3b5 100644 --- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import os import sys import torch @@ -17,8 +18,8 @@ ) from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_distributed import ( - DistributedTestBase, - requires_accelerator_dist_backend, + MultiProcessTestCase, + requires_nccl, skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN @@ -29,12 +30,9 @@ sys.exit(0) -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" - - def gpus_for_rank(world_size): - visible_devices = list(range(torch.accelerator.device_count())) - gpus_per_process = torch.accelerator.device_count() // world_size + visible_devices = list(range(torch.cuda.device_count())) + gpus_per_process = torch.cuda.device_count() // world_size gpus_for_rank = [] for rank in range(world_size): gpus_for_rank.append( @@ -62,7 +60,27 @@ def forward(self, x, rank): return self.t0(x ** (1 + rank)) -class DistributedDataParallelCommHookTest(DistributedTestBase): +class DistributedDataParallelCommHookTest(MultiProcessTestCase): + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + try: + os.remove(self.file_name) + except OSError: + pass + + def _get_process_group_nccl(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + return dist.distributed_c10d._get_default_group() + @property def world_size(self): return 2 @@ -101,14 +119,14 @@ def _run_and_get_grads(self, model): param = next(model.parameters()) return param.grad - @requires_accelerator_dist_backend() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_allreduce_hook(self): """ This unit test verifies the ``allreduce`` hook registered case gives same result with no hook registered case. """ - process_group = self.create_pg(device_type) + process_group = self._get_process_group_nccl() # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -117,14 +135,14 @@ def test_ddp_comm_hook_allreduce_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0) - @requires_accelerator_dist_backend() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_fp16compress_hook(self): """ This unit test verifies the ``fp16 compress`` hook registered case gives close result with no hook registered case. """ - process_group = self.create_pg(device_type) + process_group = self._get_process_group_nccl() # No hook registered case, get the reference grads. 
reference_grads = self._get_grads(process_group, None) @@ -133,14 +151,14 @@ def test_ddp_comm_hook_fp16compress_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) - @requires_accelerator_dist_backend() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_quantize_per_tensor_hook(self): """ This unit test verifies the ``quantize per tensor`` hook registered case gives close result with no hook registered case. """ - process_group = self.create_pg(device_type) + process_group = self._get_process_group_nccl() # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -149,14 +167,14 @@ def test_ddp_comm_hook_quantize_per_tensor_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) - @requires_accelerator_dist_backend() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_quantize_per_channel_hook(self): """ This unit test verifies the ``quantize per channel`` hook registered case gives close result with no hook registered case. """ - process_group = self.create_pg(device_type) + process_group = self._get_process_group_nccl() # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -167,14 +185,14 @@ def test_ddp_comm_hook_quantize_per_channel_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) - @requires_accelerator_dist_backend() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_noop_hook(self): """ This unit test verifies the ``noop`` hook registered case and a subsequent allreduce gives same result with no hook registered case. """ - process_group = self.create_pg(device_type) + process_group = self._get_process_group_nccl() # No hook registered case, get the reference grads. 
reference_grads = self._get_grads(process_group, None) @@ -186,10 +204,10 @@ def test_ddp_comm_hook_noop_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0) - @requires_accelerator_dist_backend() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_is_last_hook(self): - process_group = self.create_pg(device_type) + process_group = self._get_process_group_nccl() def hook(flags, bucket): flags.append(bucket.is_last()) diff --git a/test/distributed/checkpoint/test_state_dict_utils.py b/test/distributed/checkpoint/test_state_dict_utils.py index c0f850cf95c9c..76e9aeb9e3302 100644 --- a/test/distributed/checkpoint/test_state_dict_utils.py +++ b/test/distributed/checkpoint/test_state_dict_utils.py @@ -32,7 +32,7 @@ class TestStateDictUtils(DTensorTestBase): @property def world_size(self): - return min(4, torch.accelerator.device_count()) + return min(4, torch.cuda.device_count()) @with_comms @skip_if_lt_x_gpu(2) @@ -49,7 +49,7 @@ def test_gather_state_dict_dtensor(self): dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"]) - self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type) + self.assertTrue(gathered_state_dict["dtensor"].is_cuda) @with_comms @skip_if_lt_x_gpu(4) @@ -69,16 +69,14 @@ def test_gather_with_cpu_and_ranks_only(self): ) if dist.get_rank() in (0, 2): self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"]) - self.assertNotEqual( - gathered_state_dict["dtensor"].device.type, self.device_type - ) + self.assertFalse(gathered_state_dict["dtensor"].is_cuda) else: self.assertEqual(gathered_state_dict, {}) @with_comms @skip_if_lt_x_gpu(4) def test_cpu_and_ranks_only(self): - device = torch.device(self.device_type) + device = torch.device("cuda") state_dict = { "tensor1": torch.arange(10, device=device), "tensor2": torch.ones(10, device=device), @@ -87,7 +85,7 @@ def test_cpu_and_ranks_only(self): cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2)) if dist.get_rank() in (0, 2): for v in cpu_state_dict.values(): - self.assertNotEqual(v.device.type, self.device_type) + self.assertFalse(v.is_cuda) self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10)) self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10)) else: @@ -111,27 +109,27 @@ def create_dtensor(): for _ in range(10): tensor, dtensor = create_dtensor() ltensor.append(tensor) - ltensor.append(torch.ones(10, device=torch.device(self.device_type))) + ltensor.append(torch.ones(10, device=torch.device("cuda"))) ldtensor.append(dtensor) - ldtensor.append(torch.ones(10, device=torch.device(self.device_type))) + ldtensor.append(torch.ones(10, device=torch.device("cuda"))) tensor, dtensor = create_dtensor() dist_state_dict = { "local": dtensor, "list": ldtensor, - "arange": torch.arange(10, device=torch.device(self.device_type)), + "arange": torch.arange(10, device=torch.device("cuda")), } state_dict = { "local": tensor, "list": ltensor, - "arange": torch.arange(10, device=torch.device(self.device_type)), + "arange": torch.arange(10, device=torch.device("cuda")), } self.assertEqual(state_dict, _gather_state_dict(dist_state_dict)) @with_comms @skip_if_lt_x_gpu(2) def test_create_cpu_state_dict(self): - device = torch.device(self.device_type) + device = torch.device("cuda") rank = dist.get_rank() # Scale tensors based on world size # to fit in the tensor shards accurately. 
@@ -151,7 +149,7 @@ def test_create_cpu_state_dict(self): metadata=ShardMetadata( shard_offsets=[5 * rank, 0], shard_sizes=[5, 10], - placement=f"rank:{rank}/{self.device_type}:{rank}", + placement=f"rank:{rank}/cuda:{rank}", ), ) ], @@ -161,7 +159,7 @@ def test_create_cpu_state_dict(self): torch.arange(50 * scale_factor, device=device).reshape( 5 * scale_factor, 10 ), - init_device_mesh(self.device_type, mesh_shape=(self.world_size,)), + init_device_mesh("cuda", mesh_shape=(self.world_size,)), [Shard(0)], ), "non_tensor_bytes_io": copy.deepcopy(buffer), @@ -247,7 +245,7 @@ def test_state_dict_util_distribute_tensors(self): even_tensor = torch.randn(self.world_size, 2) uneven_tensor = torch.randn(1, 2) - mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,)) even_dtensor = distribute_tensor( torch.randn(self.world_size, 2), mesh, [Shard(0)] ) @@ -275,10 +273,10 @@ def test_state_dict_util_distribute_tensors(self): @with_comms @skip_if_lt_x_gpu(2) def test_cpu_offload_for_dtensor(self): - device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,)) sd = { "k": DTensor.from_local( - torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)] + torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)] ) } cpu_sd = _create_cpu_state_dict(sd) @@ -292,12 +290,12 @@ def test_cpu_offload_for_dtensor(self): self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"])) _copy_state_dict(sd, cpu_sd, non_blocking=True) - torch.accelerator.synchronize() + torch.cuda.synchronize() self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"])) sd["k"] += 1 self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"])) _copy_state_dict(sd, cpu_sd, non_blocking=True) - torch.accelerator.synchronize() + torch.cuda.synchronize() self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"])) diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 6f527dbb0257f..35eefdad512e6 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -7,7 +7,7 @@ import copy import sys -from contextlib import contextmanager, nullcontext +from contextlib import nullcontext from typing import Any, cast import numpy as np @@ -40,6 +40,7 @@ skip_if_rocm_multiprocess, skip_if_win32, ) +from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -56,17 +57,7 @@ HAS_TORCHVISION = False -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" - - -@contextmanager -def deterministic_algorithms(enabled=True): - prev_state = torch.are_deterministic_algorithms_enabled() - torch.use_deterministic_algorithms(enabled) - try: - yield - finally: - torch.use_deterministic_algorithms(prev_state) +device_type = str(get_devtype()) class TestZeroRedundancyOptimizer(DistributedTestBase): @@ -1250,7 +1241,7 @@ def _test_ddp_zero_overlap( enabled=True, deterministic=True, benchmark=False ) if "cuda" in device - else deterministic_algorithms(True) + else torch.use_deterministic_algorithms(True) ) with det_ctx: device_ids = [rank] if requires_ddp_rank(device) else None diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 473198e5421c5..0877eb53cd6f5 100644 
--- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -24,7 +24,7 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 from torch.testing._internal.common_device_type import e4m3_type from torch.testing._internal.common_distributed import ( - DistributedTestBase, + MultiProcessTestCase, requires_accelerator_dist_backend, skip_if_lt_x_gpu, ) @@ -59,8 +59,12 @@ def load_test_module(name): sys.exit(0) -@requires_accelerator_dist_backend() -class TestWithNCCL(DistributedTestBase): +@requires_accelerator_dist_backend(["nccl", "xccl"]) +class TestWithNCCL(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + @property def world_size(self) -> int: return 2 @@ -74,7 +78,16 @@ def device(self) -> torch.device: return torch.device(self.rank) def _init_process_group(self) -> None: - self.create_pg(self.device.type) + torch.accelerator.set_device_index(self.rank) + store = dist.FileStore(self.file_name, self.world_size) + backend = dist.get_default_backend_for_device(self.device.type) + + dist.init_process_group( + backend=backend, + world_size=self.world_size, + rank=self.rank, + store=store, + ) torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) @skip_if_lt_x_gpu(2) diff --git a/test/distributed/test_c10d_object_collectives.py b/test/distributed/test_c10d_object_collectives.py index 9ef04b61ab23b..594564c456068 100644 --- a/test/distributed/test_c10d_object_collectives.py +++ b/test/distributed/test_c10d_object_collectives.py @@ -11,10 +11,13 @@ print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS from torch.testing._internal.common_utils import ( run_tests, skipIfHpu, + TEST_CUDA, + TEST_HPU, TEST_WITH_DEV_DBG_ASAN, ) @@ -26,8 +29,16 @@ ) sys.exit(0) -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" -device_count = torch.accelerator.device_count() +if TEST_HPU: + DEVICE = "hpu" +elif TEST_CUDA: + DEVICE = "cuda" +else: + DEVICE = "cpu" + +device_module = torch.get_device_module(DEVICE) +device_count = device_module.device_count() +BACKEND = dist.get_default_backend_for_device(DEVICE) def with_comms(func=None): @@ -38,10 +49,11 @@ def with_comms(func=None): @wraps(func) def wrapper(self, *args, **kwargs): - if device_type != "cpu" and device_count < self.world_size: + if DEVICE != "cpu" and device_count < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) - self.pg = self.create_pg(device=device_type) + kwargs["device"] = DEVICE + self.pg = self.create_pg(device=DEVICE) try: return func(self, *args, **kwargs) finally: @@ -52,7 +64,7 @@ def wrapper(self, *args, **kwargs): class TestObjectCollectives(DistributedTestBase): @with_comms() - def test_all_gather_object(self): + def test_all_gather_object(self, device): output = [None] * dist.get_world_size() dist.all_gather_object(object_list=output, obj=self.rank) @@ -60,7 +72,7 @@ def test_all_gather_object(self): self.assertEqual(i, v, f"rank: {self.rank}") @with_comms() - def test_gather_object(self): + def test_gather_object(self, device): output = [None] * dist.get_world_size() if self.rank == 0 else None dist.gather_object(obj=self.rank, object_gather_list=output) @@ -70,7 +82,7 @@ def test_gather_object(self): @skipIfHpu @with_comms() - def 
test_send_recv_object_list(self): + def test_send_recv_object_list(self, device): val = 99 if self.rank == 0 else None object_list = [val] * dist.get_world_size() if self.rank == 0: @@ -84,7 +96,7 @@ def test_send_recv_object_list(self): self.assertEqual(None, object_list[0]) @with_comms() - def test_broadcast_object_list(self): + def test_broadcast_object_list(self, device): val = 99 if self.rank == 0 else None object_list = [val] * dist.get_world_size() # TODO test with broadcast_object_list's device argument @@ -93,7 +105,7 @@ def test_broadcast_object_list(self): self.assertEqual(99, object_list[0]) @with_comms() - def test_scatter_object_list(self): + def test_scatter_object_list(self, device): input_list = list(range(dist.get_world_size())) if self.rank == 0 else None output_list = [None] dist.scatter_object_list( @@ -111,30 +123,34 @@ def setup_sub_pg(self): my_pg = dist.new_group(ranks, use_local_synchronization=True) return rank, ranks, my_pg + @skipIfHpu @with_comms() - def test_subpg_scatter_object(self): + def test_subpg_scatter_object(self, device): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg) self.assertEqual(rank, out_list[0]) + @skipIfHpu @with_comms() - def test_subpg_all_gather_object(self): + def test_subpg_all_gather_object(self, device): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] * len(ranks) dist.all_gather_object(out_list, rank, group=my_pg) self.assertEqual(ranks, out_list) + @skipIfHpu @with_comms() - def test_subpg_gather_object(self): + def test_subpg_gather_object(self, device): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] * len(ranks) if rank == ranks[0] else None dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg) if rank == ranks[0]: self.assertEqual(ranks, out_list) + @skipIfHpu @with_comms() - def test_subpg_broadcast_object(self): + def test_subpg_broadcast_object(self, device): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] if rank == ranks[0]: @@ -143,5 +159,7 @@ def test_subpg_broadcast_object(self): self.assertEqual(ranks[0], out_list[0]) +devices = ("cpu", "cuda", "hpu") +instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices) if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 396e49949deb5..a0de1b13c6161 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -29,7 +29,7 @@ ) from torch.distributed.tensor.placement_types import _Partial, Shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase +from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -58,7 +58,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran os.environ["LOCAL_RANK"] = f"{local_rank}" -@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.") +@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.") class DeviceMeshTestGlooBackend(DTensorTestBase): @property def backend(self): diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 84b468afcfa2d..0117f67c38c11 100644 --- a/test/distributed/test_inductor_collectives.py +++ 
b/test/distributed/test_inductor_collectives.py @@ -40,6 +40,7 @@ DynamoDistributedSingleProcTestCase, MultiProcessTestCase, requires_accelerator_dist_backend, + requires_gloo, skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import ( @@ -1773,16 +1774,10 @@ def func(x, w, ar_0, ar_1, tag, ranks, group_size): inputs = [x, w, ar_0, ar_1] f(*inputs, **self.get_world_trs()) - def _pass(g): - from torch._inductor.fx_passes.bucketing import bucket_all_reduce - - bucket_all_reduce(g.owning_module, lambda _: 2000) - - torch._inductor.config.post_grad_custom_post_pass = _pass - with torch._inductor.config.patch( { "reorder_for_compute_comm_overlap": False, + "bucket_all_reduces_fx": bucket_mode, } ): compiled = torch.compile(f) @@ -2234,6 +2229,50 @@ def func(inp, group_size, group_name): ) assert est_ms_nccl > 0 + @skip_if_lt_x_gpu(2) + @requires_gloo() + def test_regression_use_nccl_estimate_with_gloo(self): + # Test checks that using nccl estimator option does not hard fail + # with backends that does not support runtime estimations, e.g. gloo + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="gloo", store=store, rank=self.rank, world_size=self.world_size + ) + group = c10d.distributed_c10d._get_default_group() + group_name = "default" + torch._C._distributed_c10d._register_process_group( + group_name, torch.distributed.group.WORLD + ) + group_size = group.size() + + def func(inp, group_size, group_name): + ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor( + inp, group_size, group_name + ) + ag_0_wait = torch.ops.c10d_functional.wait_tensor(ag_0_out) + ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_0_wait, group_size, group_name + ) + ag_1_wait = torch.ops.c10d_functional.wait_tensor(ag_1_out) + return ag_1_wait + + gm = make_fx(func)(torch.ones(4, 4), group_size, group_name) + g = gm.graph + for n in g.nodes: + if is_all_gather_into_tensor(n): + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_fake_distributed.py b/test/dynamo/test_fake_distributed.py index 41e373a50d76b..6a5a189c3bea1 100644 --- a/test/dynamo/test_fake_distributed.py +++ b/test/dynamo/test_fake_distributed.py @@ -90,12 +90,12 @@ def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", floordiv: "Sym(( """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"): - ge_1: "Sym(u0 >= 0)" = primals_1 >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None - ge_3: "Sym(u1 >= 0)" = primals_2 >= 0 - _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None - ge_5: "Sym(u2 >= 0)" = primals_3 >= 0 - _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None + ge: "Sym(u0 >= 0)" = primals_1 >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, 
"Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None + ge_1: "Sym(u1 >= 0)" = primals_2 >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None + ge_2: "Sym(u2 >= 0)" = primals_3 >= 0 + _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None floordiv: "Sym((u0//2))" = primals_1 // 2 diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 68e3a39800b6e..b34cae52d4c5f 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -727,7 +727,7 @@ def k(x): x = torch.randn(3) arg_count = ifdynstaticdefault(4, 5) # when compiled with dynamic, we don't have upper bound runtime assertions for u0 - expected_op_count = ifdynstaticdefault(10, 8) + expected_op_count = ifdynstaticdefault(9, 7) out_graph = self._test_wrap_simple( f, default_args_generator((x,)), @@ -747,7 +747,6 @@ def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"): c: "i64[u0, 1]" = l_x_.nonzero() sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0) - _check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None @@ -784,7 +783,6 @@ def forward(self, L_x_: "f32[3]"): c: "i64[u0, 1]" = l_x_.nonzero() sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0) - _check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None @@ -883,7 +881,7 @@ def k(x): x = torch.randn(3) arg_count = ifdynstaticdefault(4, 5) # when compiled with dynamic, we don't have upper bound runtime assertions for u0 - expected_op_count = ifdynstaticdefault(10, 8) + expected_op_count = ifdynstaticdefault(9, 7) out_graph = self._test_wrap_simple( f, default_args_generator((x,)), @@ -905,7 +903,6 @@ def forward(self, L_x_: "f32[3]"): c: "i64[u0, 1]" = l_x_.nonzero() sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0) - _check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None ge: "Sym(u0 >= 0)" = sym_size_int >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None @@ -956,7 +953,7 @@ def k(x): y = torch.randn(3) arg_count = ifdynstaticdefault(5, 6) # when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1 - expected_op_count = ifdynstaticdefault(17, 13) + expected_op_count = ifdynstaticdefault(15, 11) out_graph = self._test_wrap_simple( f, default_args_generator((x, y)), @@ -977,7 +974,6 @@ def forward(self, L_x_: "f32[3]", L_y_: "f32[3]"): c: "i64[u0, 1]" = l_x_.nonzero() sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0) - _check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None @@ -987,7 
+983,6 @@ def forward(self, L_x_: "f32[3]", L_y_: "f32[3]"): d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0) - _check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0 _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None diff --git a/test/export/test_export.py b/test/export/test_export.py index 204d458e77704..c60c8e82cc011 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3081,15 +3081,12 @@ def forward(self, x, y): foo = torch.ops.export.foo.default(x, y); x = None sym_size_int = torch.ops.aten.sym_size.int(foo, 0) sym_size_int_1 = torch.ops.aten.sym_size.int(foo, 1) - sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int); sym_constrain_range_for_size_default = None ge = sym_size_int >= 0; sym_size_int = None _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None - sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default_1 = None ge_1 = sym_size_int_1 >= 0; sym_size_int_1 = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None bar = torch.ops.export.bar.default(y); y = None sym_size_int_2 = torch.ops.aten.sym_size.int(bar, 0) - sym_constrain_range_for_size_default_2 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_2); sym_constrain_range_for_size_default_2 = None ge_2 = sym_size_int_2 >= 0; sym_size_int_2 = None _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_2 = None return (foo, bar)""", @@ -17743,7 +17740,6 @@ def forward(self, x, mask): def forward(self, x, mask): masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0) - sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None ge = sym_size_int_1 >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None le = sym_size_int_1 <= 1188864 diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 472ddcf556f83..1280ab45f2a82 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -38,6 +38,7 @@ _to_json_bytes, canonicalize, deserialize, + deserialize_torch_artifact, ExportedProgramDeserializer, ExportedProgramSerializer, GraphModuleSerializer, @@ -1904,6 +1905,16 @@ def forward(self, x): self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) + def test_deserialize_torch_artifact_dict(self): + data = {"key": torch.tensor([1, 2, 3])} + buf = io.BytesIO() + torch.save(data, buf) + serialized = buf.getvalue() + result = deserialize_torch_artifact(serialized) + + self.assertIsInstance(result, dict) + self.assertTrue(torch.equal(result["key"], torch.tensor([1, 2, 3]))) + @unittest.skipIf(IS_WINDOWS, 
"Cannot modify file in windows") def test_save_file(self): class Foo(torch.nn.Module): @@ -2010,7 +2021,6 @@ def forward(self, x): save(ep, buffer) buffer.seek(0) loaded_ep = load(buffer) - inp = (torch.tensor(1),) self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index 6025c90cdb4a2..474d3986eb7ad 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -1492,8 +1492,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1) nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) - ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1] @@ -1513,8 +1513,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"): clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1) nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) - ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u0 >= 0)" = sym_size_int >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None @@ -1538,8 +1538,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"): def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1) sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) - ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None alias_default: "f32[s77][1]cpu" = 
torch.ops.aten.alias.default(arg1_1) alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type) @@ -1557,8 +1557,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): def forward(self, arg0_1: "f32[2][1]cpu"): nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1) sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) - ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u0 >= 0)" = sym_size_int >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 405e46d8ded52..079be79fcc9d8 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -235,7 +235,6 @@ def mm(a, b): Y_eager = a @ b torch.testing.assert_close(Y_compiled, Y_eager, equal_nan=True) - @unittest.skip("Autotune Mismatch being investigated") @unittest.skipIf(not torch.version.hip, "ROCM only") @unittest.mock.patch.dict(os.environ, _test_env) @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index c1b41fd8ec5c3..4b9030b5cae4b 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -479,17 +479,14 @@ def test_remote_cache_load_function( if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if ( - device == "cuda" - and torch.version.hip is None - and dtype == torch.bfloat16 - and not SM80OrLater - ): + if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: raise unittest.SkipTest("requires SM80 or later") if use_static_cuda_launcher and not (device == "cuda" and bundle_triton): raise unittest.SkipTest( "Static cuda launcher requires cuda and triton bundling" ) + if use_static_cuda_launcher and TEST_WITH_ROCM: + raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM") def fn(x, y): return (x * 2, y @ y) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index d01d57f06a762..ba9dc93c651cf 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4449,16 +4449,17 @@ def __init__(self): def forward(self, x): return self.gn(x) - for dynamic in [True, False]: - torch._dynamo.reset() - metrics.reset() - mod = M().eval() - x = torch.randn(1, 32, 128, 128, 128) - with torch.no_grad(): - expected = mod(x) - compiled_m = torch.compile(mod, dynamic=dynamic) - actual = compiled_m(x) - self.assertEqual(expected, actual) + for simdlen, dynamic in itertools.product([None, 0], [True, False]): + with config.patch({"cpp.simdlen": simdlen}): + torch._dynamo.reset() + metrics.reset() + mod = M().eval() + x = torch.randn(1, 32, 128, 128, 128) + with torch.no_grad(): + expected = mod(x) + compiled_m = torch.compile(mod, dynamic=dynamic) + actual = compiled_m(x) + self.assertEqual(expected, actual) @torch._dynamo.config.patch( 
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True diff --git a/test/inductor/test_deterministic.py b/test/inductor/test_deterministic.py index 0de777dd81b5c..382838c31bed4 100644 --- a/test/inductor/test_deterministic.py +++ b/test/inductor/test_deterministic.py @@ -1,5 +1,9 @@ # Owner(s): ["module: inductor"] import contextlib +import os +import subprocess +import sys +import tempfile import unittest import torch @@ -104,6 +108,64 @@ def foo(x): else: self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0) + @parametrize("model_name", ["GoogleFnet", "BertForMaskedLM", "DistillGPT2"]) + @parametrize("training_or_inference", ["training", "inference"]) + @parametrize("precision", ["float32", "bfloat16", "float16", "amp"]) + def test_run2run_determinism(self, model_name, training_or_inference, precision): + """ + Test run2run determinism for a few huggingface models. + + The test assumes benchmarks/dynamo/huggingface.py can be found from + the current working directory. + """ + + if not os.path.exists("benchmarks/dynamo/huggingface.py"): + self.skipTest("Skip due to benchmarks/dynamo/huggingface.py not found.") + + def _setup_env(env): + env["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1" # disable autotune cache + env["TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE"] = "0" + env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0" + if enable_determinism: + env["TORCHINDUCTOR_DETERMINISTIC"] = "1" + + # set to false if you want to check how the test fails without + # the deterministic mode + enable_determinism = True + with tempfile.TemporaryDirectory() as tmpdir: + saved_pkl = os.path.join(tmpdir, "saved.pkl") + cmd = ( + f"{sys.executable} benchmarks/dynamo/huggingface.py --backend inductor" + + f" --{precision} --accuracy --only {model_name} --{training_or_inference}" + + f" --disable-cudagraphs --save-model-outputs-to={saved_pkl}" + ) + print("Command", cmd) + env = os.environ.copy() + _setup_env(env) + out = subprocess.run(cmd.split(), capture_output=True, env=env) + + # We don't check the accuracy against eager here because some + # of the combination between model and precision can not + # pass that accuracy test. But it's still valuable to make + # sure we generate bitwise equivalent result from run to run. 
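
Illustrative sketch (not part of the diff): condensed, the save-then-compare flow this determinism test drives looks roughly like the following. The environment variables and command-line flags are the ones used in the test above; the precision/model choice here is just one of the parametrized combinations, and the benchmark script is assumed to be runnable from the working directory.

import os
import subprocess
import sys
import tempfile

env = os.environ.copy()
env["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"   # disable autotune caches
env["TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE"] = "0"
env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"
env["TORCHINDUCTOR_DETERMINISTIC"] = "1"          # the mode under test

with tempfile.TemporaryDirectory() as tmpdir:
    saved = os.path.join(tmpdir, "saved.pkl")
    base = (
        f"{sys.executable} benchmarks/dynamo/huggingface.py --backend inductor"
        " --amp --accuracy --only DistillGPT2 --training --disable-cudagraphs"
    )
    # First run saves the model outputs ...
    subprocess.run(f"{base} --save-model-outputs-to={saved}".split(), env=env)
    # ... second run, with benchmarking deliberately distorted, must reproduce
    # those outputs bitwise.
    env["TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT"] = "inverse"
    out = subprocess.run(
        f"{base} --compare-model-outputs-with={saved}".split(),
        env=env,
        capture_output=True,
    )
    assert b"bitwise equivalent" in out.stdout
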
+ # self.assertTrue("pass" in out.stdout.decode()) + + cmd = ( + f"{sys.executable} benchmarks/dynamo/huggingface.py --backend inductor" + + f" --{precision} --accuracy --only {model_name} --{training_or_inference}" + + f" --disable-cudagraphs --compare-model-outputs-with={saved_pkl}" + ) + print("Command", cmd) + + # distort benchmarking results + env["TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT"] = "inverse" + out = subprocess.run(cmd.split(), capture_output=True, env=env) + self.assertTrue( + "The result is bitwise equivalent to the previously saved result" + in out.stdout.decode(), + f"stdout: {out.stdout.decode()}, stderr: {out.stderr.decode()}", + ) + if __name__ == "__main__": if HAS_CUDA_AND_TRITON: diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index f9e84284f0d8d..2c232594f3329 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -831,7 +831,9 @@ def check( gm = torch._inductor.aot_compile( ep.module(), inp, options={"fx_wrapper": True, **test_config} ) - self.assertTrue(same(model(*inp), gm(*inp))) + # Flatten args for fx_wrapper gm + flat_args, _ = pytree.tree_flatten(inp) + self.assertTrue(same(model(*inp), gm(*flat_args))) for node in gm.graph.nodes: if ( @@ -1182,6 +1184,38 @@ def mock_set_hook(gm: torch.fx.GraphModule, fn): compiled_out = compiled(*args) self.assertEqual(compiled_out.shape, shape) + def test_reshape_dynamic_ph(self): + """ + Test dynamic scalars using SymInts placeholder + """ + + class TestModule(torch.nn.Module): + def forward(self, x, shape): + return torch.reshape(x, shape) + 2 + + ds = { + "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO), + "shape": [torch.export.Dim.AUTO, torch.export.Dim.AUTO], + } + args = (torch.randn((12, 14), device=self.device), [6, 28]) + self.check(TestModule(), args, ds) + + def test_reshape_dynamic_tmd(self): + """ + Test dynamic reshape using shape dependent information + """ + + class TestModule(torch.nn.Module): + def forward(self, x): + new_shape = [x.shape[0] // 2, x.shape[1] * 2] + return torch.reshape(x, new_shape) + 2 + + ds = { + "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO), + } + args = (torch.randn((12, 14), device=self.device),) + self.check(TestModule(), args, ds) + class TestReplaceFloorDiv(InductorTestCase): """ diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 8356ecd0b6998..051a5f5905997 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -1188,20 +1188,6 @@ def fn(nodes): with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): torch.compile(f)(x) - def test_find_broadcast_var(self): - """Test broadcast variable detection for tiling improvements.""" - from torch._inductor import tiling_utils - - i, j = sympy.symbols("i j", integer=True) - - # Test broadcast pattern detection: FloorDiv creates broadcast - result = tiling_utils.find_broadcast_var(FloorDiv(i, 10), {i: 100, j: 50}) - self.assertEqual(result, i) - - # Test non-broadcast: linear access pattern - result = tiling_utils.find_broadcast_var(i + j * 10, {i: 10, j: 8}) - self.assertEqual(result, None) - class TestIndexInversion(TestCase): @classmethod diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 592e42ce41735..1114810ceccdf 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -270,11 +270,20 @@ def f(x, y): ], ) @parametrize("split_reductions", (False, 
True)) - @parametrize("shape", ((32768, 2048), (32768, 768), (32768 + 1023, 768))) + @parametrize( + "shape", ((1000000, 256), (32768, 2048), (32768, 768), (32768 + 1023, 768)) + ) @parametrize("max_autotune", (False, True)) @parametrize("initial_xblock", (1, 2)) + @parametrize("add_1dim", (False, True)) def test_rms_norm_bwd( - self, wdtype, split_reductions, shape, max_autotune, initial_xblock + self, + wdtype, + split_reductions, + shape, + max_autotune, + initial_xblock, + add_1dim, ): # max_autotune can be slow and cost resource, trim down the tests # for max autotune @@ -287,6 +296,9 @@ def test_rms_norm_bwd( ): self.skipTest("Skip non-critical tests to save resources.") + if shape != (1000000, 256) and add_1dim: + self.skipTest("Skip non-critical tests to save resources.") + def f(x, w, eps): orig_dtype = x.dtype @@ -307,6 +319,9 @@ def fwd_bwd(f): # M, N = 1152 * 500, 384 M, N = shape x = torch.randn(M, N, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True) + if add_1dim: + x = x[:, None, :] + w = torch.randn(N, dtype=wdtype, device=GPU_TYPE, requires_grad=True) dy = torch.randn_like(x) eps = 1e-5 diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index ec9586197d085..654bfd269f761 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -12,6 +12,7 @@ from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase +from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -38,9 +39,8 @@ def write_cubin_to_tmp(self, kernel: CompiledKernel) -> str: # Just used by tests for now. # TODO: derive cubin_path from wherever triton stores the cubin file on disk. tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False) - binary_key = "hsaco" if torch.version.hip else "cubin" with tmp_file: - tmp_file.write(kernel.asm[binary_key]) + tmp_file.write(kernel.asm["cubin"]) self.tmp_files.append(tmp_file) return tmp_file.name @@ -64,6 +64,7 @@ def _make_launcher( result.load_kernel(device_interface.current_device()) return result + @skipIfRocm def test_basic(self): @triton.jit def simple_kernel(arg0, arg1): @@ -90,6 +91,7 @@ def simple_kernel(arg0, arg1): # 2. triton relies on inspect.get_source to get the type annotations # so I can't even use exec() to generate the test cases. # So we'll just make a few kernels by hand + @skipIfRocm def test_unsigned_integers(self): @triton.jit def unsigned_integers( @@ -113,6 +115,7 @@ def unsigned_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_signed_integers(self): @triton.jit def signed_integers( @@ -136,6 +139,7 @@ def signed_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_basic_1arg(self): @triton.jit def simple_kernel_1_arg(arg0): @@ -160,6 +164,7 @@ def simple_kernel_1_arg(arg0): ) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_constexpr(self): # Constexprs are compiled directly into the cubin file, # so we never need to pass it to StaticCudaLauncher. 
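
Illustrative sketch (not part of the diff): the body of f in test_rms_norm_bwd is elided in the hunk above; an RMSNorm forward of roughly this shape (an assumption, not the test's exact code) is what exercises the mixed-order reduction in the backward, which reduces over rows for the input gradient and over columns for the weight gradient.

import torch

def rms_norm(x, w, eps=1e-5):
    orig_dtype = x.dtype
    xf = x.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xf * inv_rms * w).to(orig_dtype)

def fwd_bwd(f, x, w, dy):
    f(x, w).backward(dy)
    return x.grad, w.grad

x = torch.randn(32, 768, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(768, requires_grad=True)
dx, dw = fwd_bwd(torch.compile(rms_norm), x, w, torch.randn_like(x))
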
@@ -188,6 +193,7 @@ def kernel_constexpr(arg0, CONSTANT: tl.constexpr): ) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_implied_constant(self): """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr""" @@ -240,6 +246,7 @@ def triton_red_fused_any_isinf_0( launcher.run(1, 1, 1, stream, arg0, arg2, 128) self.assertEqual(arg1, arg2) + @skipIfRocm def test_kernel_no_args(self): # Just an easy way to test incompatible number of arguments @triton.jit @@ -252,6 +259,7 @@ def kernel_no_op(): stream = device_interface.get_raw_stream(device_interface.current_device()) launcher.run(1, 1, 1, stream) + @skipIfRocm def test_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -275,6 +283,7 @@ def simple_kernel(arg0, arg1): launcher.run(1, 1, 1, stream, new_arg0, arg1) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_too_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -294,6 +303,7 @@ def simple_kernel(arg0, arg1): lambda: self._make_launcher(compiled_kernel), ) + @skipIfRocm def test_kernel_empty_tensor(self): # Triton kernel generated by torch.compile of the following: # @torch.compile() @@ -354,6 +364,7 @@ def triton_poi_fused_cat_0( launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel) self.assertEqual(buf0, buf1) + @skipIfRocm def test_kernel_many_args(self): N = 200 # Make 200 arguments @@ -394,6 +405,7 @@ class TestStaticTritonCompileResult(TestCase): Tests static cuda launcher with torch.compile() """ + @skipIfRocm def test_basic_compile(self): @torch.compile def foo(x, y): @@ -403,6 +415,7 @@ def foo(x, y): y = torch.randn(10, device="cuda") self.assertEqual(foo(x, y), x + y) + @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch("compile_threads", 1) def test_incompatible_code(self): @@ -425,6 +438,7 @@ def foo(x): lambda: foo(x), ) + @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch( {"compile_threads": 1, "static_launch_user_defined_triton_kernels": True} @@ -446,6 +460,7 @@ def foo(x): x2 = x.clone().detach_() self.assertEqual(foo(x), x2 + 5) + @skipIfRocm def test_empty_tensor(self): @torch.compile() def foo(x, y): @@ -457,6 +472,7 @@ def foo(x, y): result = foo(x, y) self.assertEqual(result, torch.cat(((x * 4), y + 10))) + @skipIfRocm def test_any(self): def fn(x): return ( @@ -476,6 +492,7 @@ def fn(x): compiled_result = compiled_fn(arg) self.assertEqual(eager_result, compiled_result) + @skipIfRocm def test_disable_static_cuda_launcher(self): @torch.compile def fn(x, y): diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4132674993e1e..fc0bdd2c0be03 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14626,6 +14626,21 @@ def test_weight_norm_conv2d(self): self.assertTrue(same((ref, ref_grad), (act, act_grad), tol=1e-3)) + @skipIfMPS + def test_inner_reduction_detection(self): + if self.device == "cpu": + self.skipTest("Skip for CPU device") + + x = torch.randn(100000, 1, 256, device=self.device) + + @torch.compile + def f(x): + return x.sum(dim=(0, 1)) + + code = run_and_get_triton_code(f, x) + self.assertTrue("ReductionHint.OUTER" in code) + self.assertFalse("ReductionHint.INNER" in code) + @skip_if_halide @requires_cuda_and_triton @skip_if_cpp_wrapper("skip cpp wrapper") diff --git a/test/test_autograd.py b/test/test_autograd.py index e025a8e6e582d..5960ac8add36d 100644 --- 
a/test/test_autograd.py +++ b/test/test_autograd.py @@ -10895,6 +10895,34 @@ def func(inp): self.assertTrue(gradcheck(func, x, fast_mode=True)) + def test_grad_thread_safety(self): + import threading + from concurrent.futures import ThreadPoolExecutor + + NUM_ITERS = 10 + NUM_THREADS = 4 + + # Concurrent calls to tensor.untyped_storage() + def access_grad(tensor, barrier): + barrier.wait() + return weakref.ref(tensor.grad) + + for i in range(NUM_ITERS): + tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + (tensor**2).sum().backward() + + barrier = threading.Barrier(NUM_THREADS) + with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + futures = [ + executor.submit(access_grad, tensor, barrier) + for _ in range(NUM_THREADS) + ] + + # Check that all the grad tensors returned were the same + for future in futures: + self.assertEqual(future.result()(), tensor.grad) + self.assertIsNotNone(tensor.grad) + def index_perm_variable(shape, max_indices): if not isinstance(shape, tuple): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 56a4202cded3f..2b5606aec98d6 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -3539,7 +3539,7 @@ def _test_logaddexp(self, device, dtype, base2): if base2: ref_func = np.logaddexp2 our_func = torch.logaddexp2 - elif dtype in (torch.complex64, torch.complex128): + elif dtype in (torch.complex32, torch.complex64, torch.complex128): # numpy has not implemented logaddexp for complex def complex_logaddexp(x1, x2): x = np.stack((x1, x2)) @@ -3558,6 +3558,13 @@ def _test_helper(a, b): ref = ref_func(a.cpu().float().numpy(), b.cpu().float().numpy()) v = our_func(a, b) self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01) + elif dtype == torch.complex32: + ref = ref_func( + a.cpu().to(torch.complex64).numpy(), + b.cpu().to(torch.complex64).numpy(), + ) + v = our_func(a, b) + self.assertEqual(ref, v.to(torch.complex64), atol=0.01, rtol=0.01) else: ref = ref_func(a.cpu().numpy(), b.cpu().numpy()) v = our_func(a, b) @@ -3588,12 +3595,23 @@ def _test_helper(a, b): _test_helper(a, b) @skipIfTorchDynamo() # complex infs/nans differ under Dynamo/Inductor - @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16) + @dtypesIfCUDA( + torch.float32, + torch.float64, + torch.bfloat16, + torch.complex32, + torch.complex64, + torch.complex128, + ) @dtypes( torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128 ) def test_logaddexp(self, device, dtype): - if sys.version_info >= (3, 12) and dtype in (torch.complex64, torch.complex128): + if sys.version_info >= (3, 12) and dtype in ( + torch.complex32, + torch.complex64, + torch.complex128, + ): return self.skipTest("complex flaky in 3.12") self._test_logaddexp(device, dtype, base2=False) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 8568920e8b196..d42d1cd56600a 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -3480,7 +3480,7 @@ def _run_ind_worker_queue_test(self, batch_size, num_workers): batch_size=batch_size, shuffle=False, num_workers=num_workers, - timeout=5, + timeout=JOIN_TIMEOUT, worker_init_fn=self.dataset.worker_init_fn, ) current_worker_idx = 0 @@ -3498,33 +3498,31 @@ def _run_ind_worker_queue_test(self, batch_size, num_workers): "Flaky on Windows and MacOS https://github.com/pytorch/pytorch/issues/68643", ) def test_ind_worker_queue(self): - max_num_workers = None - if hasattr(os, "sched_getaffinity"): - try: - max_num_workers = len(os.sched_getaffinity(0)) - except Exception: - pass - if 
max_num_workers is None: - cpu_count = os.cpu_count() - if cpu_count is not None: - # Use half number of CPUs - max_num_workers = cpu_count // 2 - - if max_num_workers is None: - max_num_workers = 1 - - for batch_size in (8, 16, 32, 64): - for num_workers in range(min(6, max_num_workers)): + for batch_size in (8, 32, 64): + for num_workers in range(1, 6): self._run_ind_worker_queue_test( - batch_size=batch_size, num_workers=num_workers + 1 + batch_size=batch_size, num_workers=num_workers ) class SetAffinityDataset(IterableDataset): + def __init__(self, expected_affinity=None): + self.expected_affinity = expected_affinity + def __iter__(self): - torch.randperm(1) - after = os.sched_getaffinity(0) - return iter(after) + affinity_mask = os.sched_getaffinity(0) + return iter(affinity_mask) + + +def _worker_set_affinity_init(worker_id): + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + dataset = worker_info.dataset + if ( + isinstance(dataset, SetAffinityDataset) + and dataset.expected_affinity is not None + ): + os.sched_setaffinity(0, [dataset.expected_affinity]) @unittest.skipIf( @@ -3539,19 +3537,14 @@ def test_set_affinity_in_worker_init(self): # Choose any expected_affinity = list(old_affinity)[-1] - def worker_set_affinity(_): - os.sched_setaffinity(0, [expected_affinity]) - - dataset = SetAffinityDataset() - - if not IS_WINDOWS and not IS_MACOS: - import multiprocessing as py_mp - - py_mp.set_start_method("fork", force=True) - + # Pass expected affinity through the dataset + dataset = SetAffinityDataset(expected_affinity=expected_affinity) dataloader = torch.utils.data.DataLoader( - dataset, num_workers=2, worker_init_fn=worker_set_affinity + dataset, + num_workers=2, + worker_init_fn=_worker_set_affinity_init, ) + for sample in dataloader: self.assertEqual(sample, [expected_affinity]) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 41ce5af6a28be..b6d825b1664f5 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3532,11 +3532,11 @@ def make_non_contiguous_tensor_and_test(cnt): aot_graphs, """\ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"): - ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u1 >= 0)" = arg1_1 >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None - ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 - _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None + ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None @@ -3573,11 +3573,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", aot_graphs, """\ def forward(self, 
arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"): - ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u1 >= 0)" = arg1_1 >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None - ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 - _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None + ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None @@ -3632,21 +3632,21 @@ def func(x, y): aot_graphs, """\ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"): - ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None - ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0 - _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None + ge: "Sym(u2 >= 0)" = arg1_1 >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge = _assert_scalar = None + ge_1: "Sym(u3 >= 0)" = arg2_1 >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None - ge_4: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 - _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_4 = _assert_scalar_2 = None + ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 + _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense)) gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None _assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None _local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None - ge_5: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0 - _assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_5 = _assert_scalar_4 = None + ge_3: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0 + _assert_scalar_4 = 
torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_4 = None sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1)) gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None _assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None @@ -4068,10 +4068,10 @@ def func(x): self.assertExpectedInline( output, """\ - ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None - ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None - _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None + ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None + ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None @@ -4097,10 +4097,10 @@ def func(x): self.assertExpectedInline( output, """\ - ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None - ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None - _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None + ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None + ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None return (mul_5,)""", # noqa: B950 @@ -4283,11 +4283,11 @@ def make_non_contiguous_tensor_and_test(cnt): aot_graphs, """\ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"): - ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u1 >= 0)" = arg1_1 >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None - ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 - _assert_scalar_1 = 
torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None + ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None @@ -4319,11 +4319,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", aot_graphs, """\ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"): - ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 - _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge: "Sym(u1 >= 0)" = arg1_1 >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None - ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 - _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None + ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None diff --git a/test/test_linalg.py b/test/test_linalg.py index 9168964369920..7e3a1ebaa6f3a 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -10071,6 +10071,65 @@ def test_1_sized_with_0_strided(self, device, dtype): a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(device=device, dtype=dtype) self.assertEqual(expect, res) + @onlyCUDA + def test_logaddexp_cpu_vs_cuda_complex(self, device): + # test logaddexp with complex values produce the same values (up to machine precision) on cpu and CUDA. 
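
Illustrative sketch (not part of the diff): NumPy has no complex logaddexp, so the CPU/CUDA parity being checked in the test below leans on the usual shifted log-sum-exp identity. A naive reference, ignoring the inf/nan corner cases the test enumerates separately, can be written as:

import torch

def ref_logaddexp_complex(a, b):
    # logaddexp(a, b) = m + log(exp(a - m) + exp(b - m)); choosing m as the
    # operand with the larger real part keeps both exponentials bounded.
    m = torch.where(a.real >= b.real, a, b)
    return m + torch.log(torch.exp(a - m) + torch.exp(b - m))

a = torch.randn(3, dtype=torch.complex128)
b = torch.randn(3, dtype=torch.complex128)
torch.testing.assert_close(torch.logaddexp(a, b), ref_logaddexp_complex(a, b))
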
+ input_real = torch.tensor([0.052, -0.2115, 0.6913], dtype=torch.float64) + input_img = torch.tensor([-0.3229, -0.8374, 0.8391], dtype=torch.float64) + input_complex = torch.complex(input_real, input_img).cuda() + + other_real = torch.tensor([0.2550, 0.8769, -0.4884], dtype=torch.float64) + other_img = torch.tensor([0.6063, 0.4343, -1.4166], dtype=torch.float64) + other_complex = torch.complex(other_real, other_img).cuda() + + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + + torch.testing.assert_close(out_gpu.cpu(), out_cpu, rtol=1e-12, atol=1e-14) + + # test extreme cases (infty, -infty, and nan) are handled the same between cuda and cpu + input_complex = torch.complex(torch.tensor(float('inf')), torch.tensor(float('inf'))) + other_complex = torch.complex(torch.tensor(float('inf')), torch.tensor(float('inf'))) + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + self.assertEqual(out_gpu.cpu(), out_cpu) + + input_complex = torch.complex(torch.tensor(float('inf')), torch.tensor(float('inf'))) + other_complex = torch.complex(torch.tensor(float('inf')), torch.tensor(-float('inf'))) + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + self.assertEqual(out_gpu.cpu(), out_cpu) + + input_complex = torch.complex(torch.tensor(-float('inf')), torch.tensor(float('inf'))) + other_complex = torch.complex(torch.tensor(float('inf')), torch.tensor(float('inf'))) + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + self.assertEqual(out_gpu.cpu(), out_cpu) + + input_complex = torch.complex(torch.tensor(-float('inf')), torch.tensor(float('inf'))) + other_complex = torch.complex(torch.tensor(-float('inf')), torch.tensor(float('inf'))) + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + self.assertEqual(out_gpu.cpu(), out_cpu) + + input_complex = torch.complex(torch.tensor(-float('inf')), torch.tensor(float('inf'))) + other_complex = torch.complex(torch.tensor(-float('inf')), torch.tensor(2.)) + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + self.assertEqual(out_gpu.cpu(), out_cpu) + + input_complex = torch.complex(torch.tensor(2.), torch.tensor(float('inf'))) + other_complex = torch.complex(torch.tensor(float('inf')), torch.tensor(float('inf'))) + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + self.assertEqual(out_gpu.cpu(), out_cpu) + + input_complex = torch.complex(torch.tensor(float('nan')), torch.tensor(float('inf'))) + other_complex = torch.complex(torch.tensor(float('inf')), torch.tensor(float('inf'))) + out_gpu = torch.logaddexp(input=input_complex, other=other_complex) + out_cpu = torch.logaddexp(input=input_complex.cpu(), other=other_complex.cpu()) + self.assertEqual(out_gpu.cpu(), out_cpu) + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index a8e9be4c972a1..7a6585f3b63a8 100644 --- 
a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -747,11 +747,13 @@ def create_inputs(B=None): @onlyCUDA @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) @parametrize("M", [1, 32, 64]) - @parametrize("N", [1, 32, 64]) + @parametrize("N", [1, 64]) @parametrize("K", [1, 32, 64]) - @parametrize("batch_size", [None, 1, 32]) + @parametrize("batch_size", [None, 1]) + @parametrize("broadcast_self", [False, True]) + @parametrize("high_precision_self", [False, True]) @parametrize("backend", ["cublas", "cublaslt"]) - def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend): + def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, broadcast_self, high_precision_self, backend): if torch.version.hip: msg = "accuracy regression in hipblas and hipblaslt in ROCm 7.0 for certain shapes" if input_dtype == torch.bfloat16 and N == 1 and K == 32 and batch_size: @@ -766,19 +768,21 @@ def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, bac device = "cuda" dtype = input_dtype with blas_library_context(backend): - def create_inputs(B=None): + def create_inputs(B, broadcast_self): if B is None: a = torch.randn(M, K, device=device, dtype=dtype) b = torch.randn(K, N, device=device, dtype=dtype) - c = torch.randn(M, N, device=device, dtype=dtype) + c_shape = (M, N) if not broadcast_self else (N) + c = torch.randn(c_shape, device=device, dtype=dtype) else: a = torch.randn(B, M, K, device=device, dtype=dtype) b = torch.randn(B, K, N, device=device, dtype=dtype) - c = torch.randn(B, M, N, device=device, dtype=dtype) + c_shape = (B, M, N) if not broadcast_self else (N) + c = torch.randn(c_shape, device=device, dtype=dtype) return a, b, c - a, b, c = create_inputs(batch_size) + a, b, c = create_inputs(batch_size, broadcast_self) a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32) @@ -800,21 +804,31 @@ def create_inputs(B=None): with self.assertRaises(RuntimeError): torch.addmm(c, a, b, out_dtype=output_dtype) else: + if c.dtype != output_dtype and high_precision_self: + c = c.to(output_dtype) if batch_size: out = torch.baddbmm(c, a, b, out_dtype=output_dtype) if output_dtype == torch.float32: baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) else: baseline = torch.baddbmm(c, a, b) + # test out variant + out_ten = torch.full_like(out, float("nan")) + torch.baddbmm(c, a, b, out_dtype=output_dtype, out=out_ten) else: out = torch.addmm(c, a, b, out_dtype=output_dtype) if output_dtype == torch.float32: baseline = torch.addmm(c_fp32, a_fp32, b_fp32) else: baseline = torch.addmm(c, a, b) + # test out variant + out_ten = torch.full_like(out, float("nan")) + torch.addmm(c, a, b, out_dtype=output_dtype, out=out_ten) self.assertEqual(out.dtype, output_dtype) + self.assertEqual(out_ten.dtype, output_dtype) torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(out_ten, out, atol=0, rtol=0) @onlyCUDA diff --git a/test/test_mps.py b/test/test_mps.py index 107aa3e4609d8..a84ac7d355169 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -3332,6 +3332,14 @@ def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None): helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1) helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2) + def test_repeat_interleave_offset(self): + # Regression test for https://github.com/pytorch/pytorch/issues/167924 + counts = torch.tensor([0, 
1, 0], device="mps") + data = torch.arange(2, device="mps") + out_mps = data.repeat_interleave(counts[1:], dim=0) + out_cpu = data.cpu().repeat_interleave(counts.cpu()[1:], dim=0) + self.assertEqual(out_mps.cpu(), out_cpu) + def test_count_nonzero(self): def helper(dtype): n = [ @@ -5616,7 +5624,6 @@ def helper(n, c, h, w): helper(2, 8, 4, 5) # Test clamp_max - def test_clamp_max(self): def helper(n, c, h, w): cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False) @@ -5708,6 +5715,27 @@ def helper(n, c, h, w): helper(2, 8, 4, 5) + def test_clamp_tensor_bounds_broadcasting(self): + def helper(input_shape, bound_shape): + cpu_x = torch.randn(input_shape, device="cpu", dtype=torch.float32, requires_grad=False) + mps_x = cpu_x.detach().clone().to("mps") + + cpu_min_t = torch.randn(bound_shape, device="cpu", dtype=cpu_x.dtype, requires_grad=False) + cpu_max_t = cpu_min_t + torch.rand_like(cpu_min_t).abs() + + mps_min_t = cpu_min_t.detach().clone().to("mps") + mps_max_t = cpu_max_t.detach().clone().to("mps") + + clamp_cpu = torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t) + clamp_mps = torch.clamp(mps_x, min=mps_min_t, max=mps_max_t) + + self.assertEqual(clamp_mps.cpu(), clamp_cpu) + + helper((2, 3), (1, 2, 3)) + helper((4, 2, 3), (1, 2, 3)) + helper((2, 3), (2, 3)) + + def test_divmode(self): def helper(shape, rounding_mode): for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]: diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py index 7dcddfb0f3906..24f60fdffd520 100644 --- a/test/test_opaque_obj_v2.py +++ b/test/test_opaque_obj_v2.py @@ -121,7 +121,7 @@ def size_impl(queue: OpaqueQueue) -> int: def size_impl_fake(q: OpaqueQueue) -> int: ctx = torch._custom_op.impl.get_ctx() u0 = ctx.new_dynamic_size() - torch._check_is_size(u0) + torch._check(u0 >= 0) return u0 torch.library.define( diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 531629b082d92..94d6ece0f6369 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -20,6 +20,7 @@ PLATFORM_SUPPORTS_MX_GEMM, PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, SM100OrLater, + SM120OrLater, SM89OrLater, SM90OrLater, with_tf32_off, @@ -53,6 +54,7 @@ _IS_SM8X = False + if TEST_CUDA: _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 @@ -736,6 +738,10 @@ def test_float8_scale(self, device) -> None: @parametrize("format", ["mxfp8"] + (["nvfp4", "mxfp4"] if torch.version.cuda else [])) def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): torch.manual_seed(42) + + if format == "mxfp4" and SM120OrLater: + raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300") + total_K = K # Alias for clarity, communicating this consists of several groups along this dim input_group_end_offsets = generate_jagged_offs( G, total_K, multiple_of=32, device="cuda" @@ -799,6 +805,10 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): @parametrize("format", ["mxfp8"] + (["nvfp4", "mxfp4"] if torch.version.cuda else [])) def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format): torch.manual_seed(42) + + if format == "mxfp4" and SM120OrLater: + raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300") + # Simulate 2d-3d grouped gemm `out = input @ weight.t()` # 2D inputs with groups along M, 3D weights. 
block_size = 32 @@ -1879,6 +1889,8 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping") if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum: raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping") + if recipe == "mxfp4" and SM120OrLater: + raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300") device = "cuda" M, K, N = mkn @@ -2099,6 +2111,8 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg) @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"]) def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: + if recipe == "mxfp4" and SM120OrLater: + raise unittest.SkipTest("MXFP4 on CUDA only supported on B200/B300") M, K, N = (1024, 512, 2048) BLOCK_SIZE_K = 16 if recipe == "nvfp4" else 32 BLOCK_SIZE_MN = 128 diff --git a/test/test_sparse.py b/test/test_sparse.py index e44e0d873553a..779ce21484d20 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -1391,9 +1391,9 @@ def run_test(nnz, size): # case nnz > size[d] run_test(tlen, tlen // 2) - @onlyCPU @coalescedonoff @dtypes(torch.double, torch.cdouble) + @dtypesIfMPS(torch.float32, torch.complex64) def test_mm(self, device, dtype, coalesced): def test_shape(di, dj, dk, nnz): x, _, _ = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced) diff --git a/test/test_torch.py b/test/test_torch.py index dce0ce53ac722..01c6fb39a5a2a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -259,7 +259,8 @@ def test_storage_setitem(self, device, dtype): def test_storage_use_count(self, device): a = torch.randn(10, device=device) prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata) - self.assertEqual(prev_cf, 1) + # Two references: 'a' and the wrapper returned by untyped_storage() + self.assertEqual(prev_cf, 2) b = a.view(2, 5) self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1) @@ -9324,7 +9325,7 @@ class BadSubTensor: member_var = object() err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor" - with self.assertRaisesRegex(RuntimeError, err_msg): + with self.assertRaisesRegex(TypeError, err_msg): s0 = t0.as_subclass(BadSubTensor) # FIXME: Port to a test suite that better fits slicing @@ -10324,20 +10325,21 @@ def test_backward_hooks_traverse(self): @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") def test_tensor_dead_weak_ref(self): - x = torch.empty(2) + x = torch.ones(2) w_x = weakref.ref(x) - y = torch.empty(2) + y = torch.ones(2) y.grad = x del x x = w_x() - # Ideally, x would keep the tensor live. But CPython doesn't - # provide enough hooks to do this. So it will go dead and x - # will transmute into an undefined tensor. Not great, but the - # best we can do. + # x should keep the tensor live. This didn't happen in earlier PyTorch + # versions. del y - self.assertRaises(RuntimeError, lambda: x.sigmoid()) + self.assertEqual(2, x.sum()) + + del x + self.assertIsNone(w_x()) @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") def test_storage_dead_weak_ref(self): @@ -10345,16 +10347,9 @@ def test_storage_dead_weak_ref(self): w_x = weakref.ref(x) y = torch.tensor(x) del x - - x = w_x() - # Ideally, x would keep the storage live. But CPython doesn't - # provide enough hooks to do this. 
So it will go dead and x - # will transmute into storage with null StorageImpl. Not great, but the - # best we can do. + self.assertIsNotNone(w_x()) del y - - self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0]) - self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float()) + self.assertIsNone(w_x()) def test_tensor_resurrected_weak_ref(self): x = torch.empty(2) @@ -10415,6 +10410,31 @@ def callback(w): self.assertTrue(called) + def test_storage_thread_safety(self): + import threading + from concurrent.futures import ThreadPoolExecutor + + NUM_ITERS = 10 + NUM_THREADS = 4 + + # Concurrent calls to tensor.untyped_storage() + def access_untyped_storage(tensor, barrier): + barrier.wait() + return weakref.ref(tensor.untyped_storage()) + + for i in range(NUM_ITERS): + tensor = torch.tensor([1.0, 2.0, 3.0]) + barrier = threading.Barrier(NUM_THREADS) + with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + futures = [ + executor.submit(access_untyped_storage, tensor, barrier) + for _ in range(NUM_THREADS) + ] + + # Check that all the storages returned were the same + for future in futures: + self.assertEqual(future.result()(), tensor.untyped_storage()) + # FIXME: move to test_linalg @torch.inference_mode() def test_bmm_multithreaded(self): diff --git a/test/test_utils.py b/test/test_utils.py index f6bdc156c122e..ab2f133ca3f7e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -53,8 +53,10 @@ # sharding on sandcastle. This line silences flake warnings load_tests = load_tests # noqa: PLW0127 -HAS_CUDA = torch.cuda.is_available() - +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) +TEST_GPU = torch.xpu.is_available() or torch.cuda.is_available() from torch.testing._internal.common_utils import run_tests, TestCase @@ -302,24 +304,24 @@ def run_fn(input): self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) - @unittest.skipIf(not HAS_CUDA, "No CUDA") - def test_checkpoint_rng_cuda(self): + @unittest.skipIf(not TEST_GPU, "No accelerator") + def test_checkpoint_rng_gpu(self): for _ in range(5): - inp = torch.randn(20000, device="cuda").requires_grad_() + inp = torch.randn(20000, device=device_type).requires_grad_() phase1 = torch.nn.Dropout() phase2 = torch.nn.Dropout() def run_fn(input): return phase2(input) - state = torch.cuda.get_rng_state() + state = torch.get_device_module(device_type).get_rng_state() out = phase1(inp) out = checkpoint(run_fn, out, use_reentrant=True) out.sum().backward() grad_with_checkpointing = inp.grad - torch.cuda.set_rng_state(state) + torch.get_device_module(device_type).set_rng_state(state) inp.grad = None @@ -330,9 +332,9 @@ def run_fn(input): self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) - @unittest.skipIf(not HAS_CUDA, "No CUDA") + @unittest.skipIf(not TEST_GPU, "No accelerator") def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self): - inp = torch.randn(2, device="cuda").requires_grad_() + inp = torch.randn(2, device=device_type).requires_grad_() layer = torch.nn.Dropout() def run_fn(input): @@ -435,10 +437,10 @@ def run_fn2(tensor1, tensor2): out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True) out.sum().backward() - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") + @unittest.skipIf(not TEST_GPU, "No accelerator") def test_checkpointing_without_reentrant_early_free(self): # I don't know how to check if the temporary saved variable buffer - # get de-allocated directly. 
So using cuda memory usage as a proxy + # get de-allocated directly. So using GPU memory usage as a proxy def _do_test(fn, should_free): stats: list[int] = [] @@ -449,8 +451,8 @@ def track(x, idx): # emptied at each step) def hook(_unused): self.assertEqual(len(stats), idx) - torch.cuda.synchronize() - stats.append(torch.cuda.memory_allocated()) + torch.accelerator.synchronize() + stats.append(torch.accelerator.memory_allocated()) if idx > 0: if should_free: self.assertLess(stats[idx], stats[idx - 1]) @@ -475,7 +477,7 @@ def test_fn(x): return stats - x = torch.zeros(10, device="cuda", requires_grad=True) + x = torch.zeros(10, device=device_type, requires_grad=True) x.grad = torch.zeros_like(x) # In a regular backward, buffers get eagerly freed @@ -505,8 +507,8 @@ def test_fn(x): @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_get_device_states_recursive(self): inp = { - "foo": torch.rand(10, device="cuda:0"), - "bar": [torch.rand(10, device="cuda:1")], + "foo": torch.rand(10, device=f"{device_type}:0"), + "bar": [torch.rand(10, device=f"{device_type}:1")], } device_ids, device_states = get_device_states(inp) self.assertEqual(2, len(device_ids)) @@ -522,42 +524,42 @@ def test_infer_device_state_recursive_meta(self): self.assertEqual("meta", device_type) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - def test_infer_device_state_recursive_multi_cuda(self): - # Check that no warning is issued for either cuda:0, cuda:1 or - # cuda:0, cuda:0 cases since they are both the same device type + def test_infer_device_state_recursive_multi_gpu(self): + # Check that no warning is issued for either gpu:0, gpu:1 or + # gpu:0, gpu:0 cases since they are both the same device type inp = { - "foo": torch.rand(10, device="cuda:0"), - "bar": [torch.rand(10, device="cuda:1")], + "foo": torch.rand(10, device=f"{device_type}:0"), + "bar": [torch.rand(10, device=f"{device_type}:1")], } with warnings.catch_warnings(): warnings.simplefilter("error") - device_type = _infer_device_type(inp) - self.assertEqual("cuda", device_type) + _device_type = _infer_device_type(inp) + self.assertEqual(device_type, _device_type) inp = { - "foo": torch.rand(10, device="cuda:0"), - "bar": [torch.rand(10, device="cuda:0")], + "foo": torch.rand(10, device=f"{device_type}:0"), + "bar": [torch.rand(10, device=f"{device_type}:0")], } with warnings.catch_warnings(): warnings.simplefilter("error") - device_type = _infer_device_type(inp) - self.assertEqual("cuda", device_type) - # Check that a warning is issued for cuda:0, meta and that it includes + _device_type = _infer_device_type(inp) + self.assertEqual(device_type, _device_type) + # Check that a warning is issued for gpu:0, meta and that it includes # device type information inp = { - "foo": torch.rand(10, device="cuda:0"), + "foo": torch.rand(10, device=f"{device_type}:0"), "bar": [torch.rand(10, device="meta")], } with warnings.catch_warnings(record=True) as w: - device_type = _infer_device_type(inp) - self.assertEqual("cuda", device_type) + _device_type = _infer_device_type(inp) + self.assertEqual(device_type, _device_type) self.assertEqual(len(w), 1) warning_msg = str(w[-1].message) self.assertTrue( "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices" in warning_msg ) - self.assertTrue("Device types: ['cuda', 'meta']" in warning_msg) - self.assertTrue("first device type: cuda" in warning_msg) + self.assertTrue(f"Device types: ['{device_type}', 'meta']" in warning_msg) + self.assertTrue(f"first device type: 
{device_type}" in warning_msg) class TestDataLoaderUtils(TestCase): @@ -604,7 +606,7 @@ def test_single_drop(self): self.assertEqual(len(list(dataiter)), 1) @unittest.skip( - "FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN" + "FIXME: Intermittent GPU out-of-memory error on Windows and time-out under ASAN" ) def test_multi_keep(self): dataloader: DataLoader = DataLoader( @@ -861,27 +863,33 @@ def test_get_default_device(self): @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_get_default_device_more(self): try: - torch.set_default_device("cuda") + torch.set_default_device(device_type) self.assertEqual(torch.get_default_device(), torch.tensor([]).device) torch.set_default_device(None) - torch.set_default_device("cuda") - torch.cuda.set_device("cuda:1") + torch.set_default_device(device_type) + torch.get_device_module(device_type).set_device(f"{device_type}:1") + self.assertEqual(torch.get_default_device(), torch.tensor([]).device) + torch.accelerator.set_device_index(1) self.assertEqual(torch.get_default_device(), torch.tensor([]).device) torch.set_default_device(None) - torch.set_default_device("cuda:1") + torch.set_default_device(f"{device_type}:1") self.assertEqual(torch.get_default_device(), torch.tensor([]).device) torch.set_default_device(None) - torch.set_default_device("cuda:1") - with torch.device("cuda:0"): - self.assertEqual(torch.get_default_device(), torch.device("cuda", 0)) + torch.set_default_device(f"{device_type}:1") + with torch.device(f"{device_type}:0"): + self.assertEqual( + torch.get_default_device(), torch.device(f"{device_type}", 0) + ) torch.set_default_device("cpu") self.assertEqual(torch.get_default_device(), torch.device("cpu")) - with torch.device("cuda:0"): - self.assertEqual(torch.get_default_device(), torch.device("cuda", 0)) + with torch.device(f"{device_type}:0"): + self.assertEqual( + torch.get_default_device(), torch.device(f"{device_type}", 0) + ) self.assertEqual(torch.get_default_device(), torch.device("cpu")) finally: diff --git a/third_party/xpu.txt b/third_party/xpu.txt index a5031de150288..f05ce60393d66 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -9aac5a1ddf50d75f929d572df51bb368b32da14e +1e69f40b3c03492eb3dd7e03462a5566f29674d3 diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 7194ef034bb5a..e9b58b9ce71eb 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -4,7 +4,6 @@ # ruff: noqa: F401 from collections.abc import Callable, Iterable, Iterator, Sequence -from contextlib import AbstractContextManager from enum import Enum, IntEnum from pathlib import Path from types import EllipsisType @@ -231,8 +230,8 @@ ${dtype_class_hints} class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp -def DisableTorchFunction() -> AbstractContextManager: ... -def DisableTorchFunctionSubclass() -> AbstractContextManager: ... +def DisableTorchFunction(): ... +def DisableTorchFunctionSubclass(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp strided: layout = ... 
diff --git a/torch/__init__.py b/torch/__init__.py index e39e50a1f8409..ba8f60f5fffe0 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -33,7 +33,11 @@ TypeVar as _TypeVar, Union as _Union, ) -from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs +from typing_extensions import ( + deprecated as _deprecated, + ParamSpec as _ParamSpec, + TypeIs as _TypeIs, +) # As a bunch of torch.packages internally still have this check @@ -1735,7 +1739,10 @@ def _check(cond, message=None): # noqa: F811 _check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type] -# TODO add deprecation annotation +@_deprecated( + "_check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. \ + Use _check(i >= 0) instead." +) def _check_is_size(i, message=None, *, max=None): """Checks that a given integer is a valid size (i.e., is non-negative). You should use this over ``_check(i >= 0)`` because it can prevent diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index e328422ec5e66..84978f0066712 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -422,7 +422,17 @@ def deserialize_torch_artifact( buffer = io.BytesIO(serialized) buffer.seek(0) # weights_only=False as we want to load custom objects here (e.g. ScriptObject) - artifact = torch.load(buffer, weights_only=False) + try: + artifact = torch.load(buffer, weights_only=True) + except Exception as e: + buffer.seek(0) + artifact = torch.load(buffer, weights_only=False) + log.warning( + "Fallback to weights_only=False succeeded. " + "Loaded object of type %s after initial failure: %s", + type(artifact), + exc_info=e, + ) assert isinstance(artifact, (tuple, dict)) return artifact diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e3e5913be7d76..88f203421cc1c 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -239,7 +239,10 @@ def reduction_combine( if reduction_type in ("min", "max"): return f"{reduction_type}_propagate_nan({var}, {next_value})" if reduction_type == "welford_reduce": - return f"welford_combine({var}, {next_value})" + if helper_val: + return f"welford_combine({var}, {next_value}, &{helper_val})" + else: + return f"welford_combine({var}, {next_value})" if reduction_type == "welford_combine": if isinstance(next_value, tuple): mean, m2, weight = next_value @@ -2194,10 +2197,8 @@ def need_use_acc_helper(self, reduction_type, dtype, use_scalar): # sum and welford # Note: using helper has non-negligible impact on performance - # keep the original behavior for welford_reduce - # acc helper is not used for scalar welford_reduce if reduction_type == "welford_reduce": - return not use_scalar + return True # TODO add supports for more data types when needed if reduction_type == "sum" and dtype == torch.float: @@ -2323,9 +2324,15 @@ def reduction(self, dtype, src_dtype, reduction_type, value): reduction_size = functools.reduce( operator.mul, self.ranges[self.reduction_depth :] ) - helper_val = self.cascade_helper_cse.generate( - self.compute, f"reduction {reduction_key}", write=False - ) + # use welford_helper/cascade_helper for vec kernel + if reduction_type == "welford_reduce": + helper_val = self.welford_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + else: + helper_val = self.cascade_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) # rename the helper variable to distinguish it from 
vectorized version scalar_helper_val = f"scalar_{helper_val}" self._use_acc_helper( @@ -3092,19 +3099,16 @@ def reduction(self, dtype, src_dtype, reduction_type, value): if self.ranges[self.tiling_idx] % self.tiling_factor else sympy.Integer(0) ) - # scalar helper for scalar sum is also needed when vec kernel is included - # Note: is it different from welford reduction as welford reduction of scalar version - # does not need helper, and the helper needs the information of reduction size to initialize - if reduction_type == "sum": - scalar_helper_val = f"scalar_{helper_val}" - self._use_acc_helper( - reduction_type, - acc, - scalar_helper_val, - reduction_size, - dtype, - use_scalar=True, - ) + # scalar helper for scalar welford_reduce/sum is also needed when vec kernel is included + scalar_helper_val = f"scalar_{helper_val}" + self._use_acc_helper( + reduction_type, + acc, + scalar_helper_val, + reduction_size, + dtype, + use_scalar=True, + ) self._use_acc_helper( reduction_type, acc, helper_val, helper_vec_range, dtype ) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 65e8f88b1c425..2ad02ca97a54b 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -2819,8 +2819,6 @@ def process_node_vars( bad_size_additional_tiling_penalty = 1.025 good_size_tiling_penalty = 1.005 - total_uncoalesced = sum(coalesce_analysis.uncoalesced_addrs.values()) - def score_mod(t): score_factor = 1.0 for tile_size in t[0].tiling.values(): @@ -2829,19 +2827,12 @@ def score_mod(t): else: score_factor = score_factor / good_size_tiling_penalty - # Add uncoalesced memory score to prevent small coalesced benefits - # from dominating large amounts of uncoalesced memory - uncoalesced_penalty = total_uncoalesced * 0.05 - - return -(t[0].score + uncoalesced_penalty) * score_factor + return -t[0].score * score_factor # apply penalty for longer tilings that dont increase score much for cand, tiling_score in sorted(tilings, key=score_mod): - if ( - cls.tiling_is_compatible( - node_schedule, pointwise_numel, reduction_numel, cand.tiling - ) - or cand.tiling == default_tiling + if cls.tiling_is_compatible( + node_schedule, pointwise_numel, reduction_numel, cand.tiling ): # we always include default reduction numel == 1, dont include tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 74a58acb84ff3..e6f85204a2c14 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -196,16 +196,9 @@ def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]: if "all_gather_into_tensor_out" in py_kernel_name: args = args[1:] + args[0] - try: - with torch.distributed._time_estimator( - group=pg, device=device - ) as time_estimator: - w = fn(*args, **kwargs) - torch.ops._c10d_functional.wait_tensor.default(w) - except Exception as e: - # NCCL estimator can fail - log.info(e) # noqa: G200 - return None + with torch.distributed._time_estimator(group=pg, device=device) as time_estimator: + w = fn(*args, **kwargs) + torch.ops._c10d_functional.wait_tensor.default(w) est_time_us = time_estimator.estimated_time # -1000 constant is NCCL return in case of error during estimations. 
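The comm_analysis.py hunk above removes the try/except guard around the NCCL time estimator and instead relies on the explicit capability checks added further down. For orientation, a stripped-down sketch of how that estimator context manager is typically driven is shown here; it assumes an initialized NCCL process group, and the helper name estimate_all_reduce_us is illustrative rather than part of the patch:

    from typing import Optional

    import torch
    import torch.distributed as dist

    def estimate_all_reduce_us(t: torch.Tensor) -> Optional[float]:
        # Assumes dist.init_process_group("nccl", ...) has already run and
        # that `t` lives on the current CUDA device.
        pg = dist.group.WORLD
        backend = pg._get_backend(t.device)
        if not backend.supports_time_estimate:
            # Fake or non-NCCL backends cannot estimate collective runtimes.
            return None
        with torch.distributed._time_estimator(group=pg, device=t.device) as est:
            w = torch.ops._c10d_functional.all_reduce(t, "sum", pg.group_name)
            torch.ops._c10d_functional.wait_tensor(w)
        # estimated_time is reported in microseconds; a negative value (the
        # -1000 sentinel noted above) signals an estimation error inside NCCL.
        return est.estimated_time
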
@@ -359,7 +352,6 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: def estimate_nccl_collective_runtime_from_fx_node( fx_node: torch.fx.Node, override_size: Optional[int] = None, - # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix. use_nccl_estimator: bool = True, ) -> float: """ @@ -398,6 +390,20 @@ def estimate_nccl_collective_runtime_from_fx_node( def _nccl_estimate() -> Optional[float]: # TODO: Refactor with estimate_nccl_collective_runtime_nccl_estimator + from torch.distributed.distributed_c10d import ( + _get_pg_default_device, + _resolve_process_group, + ) + + pg = _resolve_process_group(group_name) + if torch.distributed.distributed_c10d.get_backend(pg) == "fake": + # nccl estimator requires real process group + return None + + device = _get_pg_default_device(pg) + backend = pg._get_backend(device) + if not backend.supports_time_estimate: + return None flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs)) @@ -421,13 +427,6 @@ def to_real_tensor(e: Any) -> Any: flat_args = [to_real_tensor(a) for a in flat_args] real_args, real_kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec) - from torch.distributed.distributed_c10d import _resolve_process_group - - pg = _resolve_process_group(group_name) - if torch.distributed.distributed_c10d.get_backend(pg) == "fake": - # nccl estimator requires real process group - return None - fn = fx_node.target assert isinstance(fn, torch._ops.OpOverload) with torch.distributed._time_estimator(group=pg) as time_estimator: @@ -441,7 +440,7 @@ def to_real_tensor(e: Any) -> Any: est_time_ms = est_time_us / 1e3 return est_time_ms - if torch.distributed.is_nccl_available() and use_nccl_estimator: + if use_nccl_estimator: est_time_ms = _nccl_estimate() if est_time_ms is not None: return est_time_ms diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b6796d2b7ce38..46ca60483828d 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -2537,16 +2537,19 @@ def _extract_inputs_from_exported_gm( fake_inputs = [ node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder" ] - # Replace non-tensor (constant) inputs with Nones, since these are not being - # used anyways by the graph - fake_inputs = [ - inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs - ] + + if not config.fx_wrapper: + # Replace non-tensor inputs with Nones + # constant scalars embedded in the graph + # symbolic scalars (symint) are not supported in non-fx_wrapper mode + fake_inputs = [ + inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs + ] if any(v is not None for v in fake_inputs): # Validate devices before switching to fake tensors. 
for idx, fi, i in zip(count(), fake_inputs, example_inputs_): - if fi is not None: + if fi is not None and isinstance(fi, torch.Tensor): assert isinstance(i, torch.Tensor) if fi.device != i.device: raise ValueError( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 5152915b7d0be..2f28ce551b103 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -421,6 +421,10 @@ def prologue_fusion_enabled() -> bool: None ) +bucket_all_reduces_fx: Literal["none", "all"] = "none" +# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used +bucket_all_reduces_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None + # runtime estimation function for ops # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle estimate_op_runtime = "default" diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 20cd5ca9a8888..98280b5af783c 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -536,9 +536,14 @@ def expired(self) -> bool: if self.extra_ref_check is not None and not self.extra_ref_check(): return False - # if extra_ref_check is not None we expect an additional reference stor_count = torch._C._storage_Use_Count(self.ref.cdata) - return (stor_count - (self.extra_ref_check is not None)) == 0 + if self.extra_ref_check is not None: + # if extra_ref_check is not None we expect two additional references: + # - one from the Python storage object + # - one from the cached Tensor + stor_count -= 2 + assert stor_count >= 0 + return stor_count == 0 def __repr__(self) -> str: if self.ref is None or self.ref.expired(): @@ -1439,7 +1444,15 @@ def check_refcount(i: int) -> bool: self_loc = self_ref() if self_loc is None: return False - return self_loc.get_output_refcount(i) == 2 + refcount = self_loc.get_output_refcount(i) + # pyrefly: ignore + if self_loc.cached_tensor_outputs[i]._use_count() > 1: + # c10::Tensor may also holds one reference count + assert refcount >= 3 + return refcount == 3 + else: + assert refcount >= 2 + return refcount == 2 check = functools.partial(check_refcount, i=i) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 958a52fcdf510..e0362f2aaafd4 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -222,6 +222,18 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): ) collectives_bucketing = True + if config.bucket_all_reduces_fx != "none": + from torch._inductor.fx_passes.bucketing import bucket_all_reduce + + GraphTransformObserver(gm, "bucket_all_reduce").apply_graph_pass( + lambda graph: bucket_all_reduce( + graph.owning_module, + config.bucket_all_reduces_fx_bucket_size_determinator, + config.bucket_all_reduces_fx, # type: ignore[arg-type] + ) + ) + collectives_bucketing = True + # Fx all_gather bucketing introduces mutation op # Keeping it in the end to keep invariant of functional graph for previous passes. if config.bucket_all_gathers_fx != "none": diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 67e0174443882..72d8383d2b812 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1435,7 +1435,9 @@ def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: strides = V.graph.sizevars.stride_hints( j, reduction_vars, list(ranges1.keys()) ) - outer = all(s > 1 for s in strides) + # A 0 stride does not make a reduction contiguous. 
+ # This can happen when the reduction ranges contains a 1. + outer = all(s == 0 or s > 1 for s in strides) if outer: num_outer += 1 else: diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index 4eede8631e9ce..f48f351ce823a 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -38,20 +38,7 @@ def __init__(self, kernel: CompiledKernel) -> None: # pyrefly: ignore [missing-attribute] self.name = kernel.src.fn.__name__ # pyrefly: ignore [missing-attribute] - if "hsaco" in kernel.asm: - # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm["hsaco"] - self.is_rocm = True - # pyrefly: ignore [missing-attribute] - elif "cubin" in kernel.asm: - # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm["cubin"] - self.is_rocm = False - else: - raise RuntimeError( - "Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm" - ) - + self.cubin_raw = kernel.asm.get("cubin", None) # pyrefly: ignore [missing-attribute] self.cubin_path = kernel._cubin_path @@ -258,42 +245,12 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher + # Add a None if triton wants extra parameters for scratch spaces arg_tys = self.arg_tys - - if self.is_rocm: - # ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both - # global_scratch and profile_scratch parameters in the kernel signature, - # even when the kernel doesn't use them (i.e., when has_*_scratch is False). - # - # This differs fundamentally from CUDA, where these parameters are only - # present in the signature if the corresponding has_*_scratch flag is True. - # - # The flags indicate whether memory will be allocated/used: - # - has_global_scratch: Whether global scratch workspace is needed - # - has_profile_scratch: Whether profiling instrumentation is enabled - # - # However, regardless of flag values, we MUST always pass both parameters - # to match the HIP kernel ABI. Passing None is safe: - # - # - If scratch is not needed (has_*_scratch=False or scratch_size=0): - # The None becomes nullptr, which the kernel never dereferences - # - # - If scratch is needed (has_*_scratch=True and scratch_size>0): - # The None becomes nullptr initially, but the HIP runtime intercepts - # the kernel launch, allocates the required scratch memory based on - # kernel metadata, and replaces the nullptr with a valid pointer before - # the kernel actually executes - # - # Not passing both parameters causes segmentation faults because the kernel - # expects them at specific positions in the argument array. 
- arg_tys = arg_tys + "OO" - args = (*args, None, None) - - else: - for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: - if has_scratch: - arg_tys = arg_tys + "O" - args = (*args, None) + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) # pyrefly: ignore [bad-argument-type] assert len(args) == len(arg_tys) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 2a83f7b59117d..d5851eeceeb24 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -22,7 +22,6 @@ import torch from torch._dynamo.utils import counters, set_feature_use -from torch._environment import is_fbcode from torch._inductor import metrics from torch._prims_common import compute_required_storage_length from torch.utils._debug_mode import get_active_debug_mode @@ -1609,8 +1608,9 @@ def can_statically_launch( return None def check_can_launch() -> StaticallyLaunchedCudaKernel: - if triton_meta.get("device_type") not in ("cuda", "hip"): - raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device") + if triton_meta.get("device_type") != "cuda": + # Only cuda kernels + raise CannotStaticallyLaunchKernel("Non-cuda device") if torch._inductor.config.cpp_wrapper: # If we're running with cpp wrapper, it doesn't @@ -1636,11 +1636,10 @@ def check_can_launch() -> StaticallyLaunchedCudaKernel: "static launch does not support launch attributes" ) - binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin" cubin_location = os.path.join( triton_cache_dir(triton_meta.get("device", 0)), triton_hash_to_path_key(kernel.hash), - f"{kernel.src.fn.__name__}.{binary_ext}", + f"{kernel.src.fn.__name__}.cubin", ) if not os.path.exists(cubin_location): @@ -1672,11 +1671,10 @@ def reload_cubin_path(self): When loading from cache on disk, we want to reload cubin files from their appropriate location on disc. 
""" - binary_ext = "hsaco" if torch.version.hip else "cubin" cubin_location = os.path.join( triton_cache_dir(self.compile_meta.get("device", 0)), triton_hash_to_path_key(self.kernel.hash), - f"{self.kernel.name}.{binary_ext}", + f"{self.kernel.name}.cubin", ) if not os.path.exists(cubin_location): if self.kernel.cubin_raw is not None: @@ -2470,9 +2468,8 @@ def total_numel() -> int: rnumels[prefix] *= 2 if num_warps is None: - if reduction_hint == ReductionHint.INNER and not is_fbcode(): - # r is contiguous, so ensure that each thread has 8 elements for - # vectorized loads, assuming bf16/fp16 + if reduction_hint == ReductionHint.INNER: + # r is contiguous, ensure at least 8 elements per thread # xblock is usually 1-2, default to giving each thread more work num_warps = r // 128 else: @@ -2942,7 +2939,7 @@ def outer_config_opt(): ) contiguous_config = make_config( - 2 if rnumel <= 2048 and not is_fbcode() else 1, # 1024 or less is persistent + 2 if rnumel <= 2048 else 1, # 1024 or less is persistent min(rnumel, MAX_R0_BLOCK), register_intensive=register_intensive, ) @@ -2955,7 +2952,7 @@ def outer_config_opt(): outer_config = make_config(64, 8, register_intensive=register_intensive) # TODO (paulzhan): Test heuristic on AMD and internal testing # for correctness - if not torch.version.hip and not is_fbcode(): + if not torch.version.hip: outer_config = outer_config_opt() configs = [] diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index 4a4efdccf4b38..5b394b9ea9914 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -145,38 +145,6 @@ def indexing_div_rep( return None -def find_broadcast_var( - index: sympy.Expr, var_ranges: dict[sympy.Expr, int] -) -> Optional[sympy.Expr]: - """ - Try to find the variable that this index is broadcast over. - A broadcast pattern is one where consecutive values of a variable - access the same memory location (e.g., x // 10). - """ - # Approximate analysis by evaluating at 1 and 0 - variables: dict[sympy.Symbol, int] = {} - for v in index.free_symbols: - if v in var_ranges: - variables[v] = 0 - else: - variables[v] = get_hint(v) - - zero_index = sympy_subs(index, variables) - for v in var_ranges.keys(): - variables[v] = 1 - try: - new_val = sympy_subs(index, variables) - except ZeroDivisionError: - loop_tiling_log.info("zero division error %s %s", index, variables) - continue - # Broadcast means the value doesn't change when the variable increments - if new_val == zero_index: - return v - variables[v] = 0 - - return None - - def find_coalesced_var( index: sympy.Expr, var_ranges: dict[sympy.Expr, int] ) -> Optional[sympy.Expr]: @@ -600,12 +568,11 @@ def remove_identity(expr: sympy.Expr) -> sympy.Expr: return fused_out -def get_score( - addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int], buf_names: OrderedSet[str] -) -> int: +def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int: """ - Score addr according to its approximate size. 
+ Score addr according to its approximate size """ + # TODO - deduplicate with candidate_tilings var_sizes = [] for v in addr.free_symbols: @@ -620,15 +587,6 @@ def get_score( ) -def try_get_buf_size(buf_name: str) -> Optional[int]: - buf = V.graph.try_get_buffer(buf_name) - if not buf: - return None - return V.graph.sizevars.atomically_apply_size_hint( - sympy_product(buf.get_size()), fallback=config.unbacked_symint_fallback - ) - - def get_hint(v: Union[sympy.Expr, int]) -> int: if isinstance(v, int): return v @@ -654,8 +612,6 @@ class CoalesceVarAnalysis: # TODO: separate into dataclass that olds mem, dtype, is_write coalesced_by_var: dict[sympy.Expr, int] - uncoalesced_addrs: dict[sympy.Expr, int] - norm_read_writes: FusedNormalizedReadsWrites suggested_split: Optional[VarTiling] = None @@ -701,40 +657,28 @@ def analyze_memory_coalescing( if indirect_expr: continue - size = get_score(memory_expr, var_ranges, buf_names) - + size = get_score(memory_expr, var_ranges) if size == 0: continue maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges) - # while broadcasting vars are not technically coalesced, - # accesses at least stay in cache, so they provide most of the benefit. - # treat the same for now. - if maybe_coalesced_var is None: - maybe_coalesced_var = find_broadcast_var(memory_expr, var_ranges) - total_score = 0 + byte_multipler = 0 for buf_name in buf_names: - if (buf := V.graph.try_get_buffer(buf_name)) and ( - buf_size := try_get_buf_size(buf_name) - ): - # constrain by buf size since we'll read at most that many elements - # score could be more through either masking or by broadcasting (e.g. x // 16) - total_score += min(buf_size, size) * buf.dtype.itemsize + if buf := V.graph.try_get_buffer(buf_name): + byte_multipler += buf.dtype.itemsize # coalesced writes more important - total_score *= 1 if is_read else 2 + byte_multipler *= 1 if is_read else 2 if maybe_coalesced_var: - coalesced_by_var[maybe_coalesced_var] += total_score + coalesced_by_var[maybe_coalesced_var] += size * byte_multipler else: - uncoalesced_addrs[memory_expr] += total_score + uncoalesced_addrs[memory_expr] += size * byte_multipler if not uncoalesced_addrs: return CoalesceVarAnalysis( - coalesced_by_var=coalesced_by_var, - uncoalesced_addrs=uncoalesced_addrs, - norm_read_writes=norm_read_writes, + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes ) # map from var -> tiling -> total_score @@ -778,9 +722,7 @@ def analyze_memory_coalescing( if len(tiling_scores) == 0: return CoalesceVarAnalysis( - coalesced_by_var=coalesced_by_var, - uncoalesced_addrs=uncoalesced_addrs, - norm_read_writes=norm_read_writes, + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes ) best_tiling: Optional[tuple[sympy.Expr, int]] = None @@ -794,9 +736,7 @@ def analyze_memory_coalescing( if best_tiling is None: return CoalesceVarAnalysis( - coalesced_by_var=coalesced_by_var, - uncoalesced_addrs=uncoalesced_addrs, - norm_read_writes=norm_read_writes, + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes ) # TODO - for strictly pointwise fusions, @@ -805,7 +745,6 @@ def analyze_memory_coalescing( # TODO - could also prefer index var splits to reduction, better tested return CoalesceVarAnalysis( coalesced_by_var=coalesced_by_var, - uncoalesced_addrs=uncoalesced_addrs, norm_read_writes=norm_read_writes, suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score), ) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 
ec28bfbd825e2..fa43af2701171 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -52,7 +52,26 @@ "MemRecordsAcc", ] -from contextlib import ContextDecorator +try: + # Available in Python >= 3.2 + from contextlib import ContextDecorator as _ContextDecorator +except ImportError: + import functools + + class _ContextDecorator: # type: ignore[no-redef] + def __enter__(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + + def __call__(self, func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapped # global python state - whether profiler is currently enabled @@ -209,12 +228,12 @@ def __init__( FutureWarning, stacklevel=2, ) - self.use_device: str | None = "cuda" + self.use_device: Optional[str] = "cuda" else: self.use_device = use_device # TODO Consider changing _function_events into data structure with size cap - self._function_events: EventList | None = None - self._old_function_events: EventList | None = None + self._function_events: Optional[EventList] = None + self._old_function_events: Optional[EventList] = None # Function event processing is done lazily self._needs_processing = False self.entered = False @@ -229,7 +248,7 @@ def __init__( if experimental_config is None: experimental_config = _ExperimentalConfig() self.experimental_config = experimental_config - self.kineto_results: _ProfilerResult | None = None + self.kineto_results: Optional[_ProfilerResult] = None self.profiling_start_time_ns = 0 self.profiling_end_time_ns = 0 self._stats = _ProfilerStats() @@ -725,7 +744,8 @@ def createFunctionEventForMemoryEvents(evt): return all_function_events -class record_function(ContextDecorator): +# pyrefly: ignore [invalid-inheritance] +class record_function(_ContextDecorator): """Context manager/function decorator that adds a label to a code block/function when running autograd profiler. Label will only appear if CPU activity tracing is enabled. @@ -764,13 +784,16 @@ class record_function(ContextDecorator): """ - def __init__(self, name: str, args: str | None = None): + def __init__(self, name: str, args: Optional[str] = None): self.name: str = name - self.args: str | None = args + self.args: Optional[str] = args # Whether or not we should run record function's end callbacks when exiting. 
self.run_callbacks_on_exit: bool = True + # TODO: TorchScript ignores standard type annotation here + # self.record: Optional["torch.classes.profiler._RecordFunction"] = None self.record = torch.jit.annotate( - Optional[torch.classes.profiler._RecordFunction], + # pyrefly: ignore [not-a-type] + Optional["torch.classes.profiler._RecordFunction"], None, ) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 7d3f9ecc4d007..adf1c8c4c4d20 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -398,36 +398,27 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { // weak_use_count() adds 1 if use_count is non-zero TORCH_CHECK( - a->cdata->weak_use_count() == 1, + a->cdata.weak_use_count() == 1, "Expected no weakrefs to t1's Tensor object but got ", - a->cdata->weak_use_count() - 1); + a->cdata.weak_use_count() - 1); TORCH_CHECK( - b->cdata->weak_use_count() == 1, + b->cdata.weak_use_count() == 1, "Expected no weakrefs to t2's Tensor object but got ", - b->cdata->weak_use_count() - 1); + b->cdata.weak_use_count() - 1); + + // NB: Creating local copies of *both* Tensors here ensures that they each + // hold a strong reference to their PyObject. This avoids having to fix up + // reference counts when we swap the PyObject slots below. + at::Tensor tmp_a = a->cdata; + at::Tensor tmp_b = b->cdata; // Swap the Tensor Impl - c10::MaybeOwned tmp = a->cdata; - - // The TensorImpls contain PyObjectSlots that have a reference to the PyObject - // associated with the TensorImpl. Swap this field as well. - std::optional mb_obj_a = - a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - std::optional mb_obj_b = - b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - TORCH_INTERNAL_ASSERT( - mb_obj_a.has_value() && mb_obj_b.has_value(), - "Both tensors should have PyObjects tagged by the current python interpreter"); - TORCH_CHECK(mb_obj_a.value() == a_); - TORCH_CHECK(mb_obj_b.value() == b_); - - a->cdata = b->cdata; - b->cdata = tmp; - - a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_); - b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_); + a->cdata = tmp_b; + b->cdata = tmp_a; + + // Fix up the PyObjects associated with each TensorImpl + a->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(a_); + b->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(b_); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -2159,7 +2150,7 @@ PyObject* initModule() { #ifdef USE_CUDA torch::cuda::initModule(module); #endif -#if defined(USE_CUDA) +#if defined(USE_CUDA) && !defined(USE_ROCM) ASSERT_TRUE(StaticCudaLauncher_init(module)); #endif #ifdef USE_MPS diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index 974f95999f17b..7f36d88bdaa32 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -45,7 +45,9 @@ struct ConcretePyInterpreterVTable final std::string name() const override; void incref(PyObject* pyobj) const override; - void decref(PyObject* pyobj, bool has_pyobj_slot) const override; + void decref(PyObject* pyobj) const override; + bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override; + size_t refcnt(PyObject* pyobj) const override; // TODO: Need to make this work for StorageImpl too. 
I imagine I'll want to // operate upon a PyObjectSlot rather than a TensorImpl @@ -235,53 +237,13 @@ py::object torchDispatchFromTensorImpl( TorchFunctionName::TorchDispatch)); } -// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg] -// Before calling PyInterpreter::decref, we must statically know if the -// pyobj has a PyObjectSlot or not. -// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection -// - If it does not have a PyObjectSlot, we can freely decref -// One alternative to this is using PyObject_IsInstance -// to get at this information. However, we don't want to risk an incorrect -// `__instancecheck__` changing the semantics here. -void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot) - const { +void ConcretePyInterpreterVTable::decref(PyObject* pyobj) const { // Leak the pyobj if not initialized. This can happen if we are running // exit handlers that are destructing tensors with residual (owned) // PyObjects stored in them. if (!Py_IsInitialized()) return; - pybind11::gil_scoped_acquire gil; - // Two possibilities: - // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or - // Storage. Then we must be careful about PyObject resurrection (see - // THPVariable_clear). - // 2. We are decref-ing some other Python object. We don't do - // PyObject resurrection on non-Tensors, so we just carry on as usual - if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) { - if (THPVariable_Check(pyobj)) { - // It's still alive! This can happen if a weak ref resurrected - // the PyObject without flipping ownership. At this point it is - // too late to rescue the object, so just stub out the PyObject - // so that it fails on subsequent uses. Don't raise an error here; - // you're probably in a destructor. - TORCH_WARN( - "Deallocating Tensor that still has live PyObject references. " - "This probably happened because you took out a weak reference to " - "Tensor and didn't call _fix_weakref() after dereferencing it. " - "Subsequent accesses to this tensor via the PyObject will now fail."); - (reinterpret_cast(pyobj))->cdata = - c10::MaybeOwned(); - } else if (THPStorage_Check(pyobj)) { - TORCH_WARN( - "Deallocating UntypedStorage that still has live PyObject references. " - "This probably happened because you took out a weak reference to " - "UntypedStorage and didn't call _fix_weakref() after dereferencing it. 
" - "Subsequent accesses to this storage via the PyObject will now fail."); - (reinterpret_cast(pyobj))->cdata = - c10::MaybeOwned(); - } - } Py_DECREF(pyobj); } @@ -292,6 +254,25 @@ void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const { Py_INCREF(pyobj); } +bool ConcretePyInterpreterVTable::try_incref( + const c10::impl::PyObjectSlot& pyobj_slot) const { + if (!Py_IsInitialized()) + return false; + pybind11::gil_scoped_acquire gil; + PyObject* pyobj = pyobj_slot.load_pyobj(); + if (!pyobj) { + return false; + } + return PyUnstable_TryIncRef(pyobj); +} + +size_t ConcretePyInterpreterVTable::refcnt(PyObject* pyobj) const { + if (!Py_IsInitialized() || pyobj == nullptr) + return 0; + pybind11::gil_scoped_acquire gil; + return Py_REFCNT(pyobj); +} + bool isPythonTensor(const at::Tensor& tensor) { return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); } @@ -620,11 +601,7 @@ static void set_tensor_attr_with_capsule( const c10::TensorImpl* tensor, py::capsule& capsule, const char* attr_name) { - std::optional mb_obj = tensor->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - TORCH_CHECK( - mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); - auto obj = mb_obj.value(); + PyObject* obj = tensor->pyobj_slot()->load_pyobj(); py::handle(obj).attr(attr_name) = capsule; } @@ -648,11 +625,7 @@ static c10::ArrayRef get_set_cached_attr( const c10::TensorImpl* tensor, const char* base_attr_name, const py::object& obj) { - std::optional mb_obj = - tensor->pyobj_slot()->check_pyobj(getPyInterpreter()); - TORCH_CHECK( - mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); - auto tensor_obj = mb_obj.value(); + PyObject* tensor_obj = tensor->pyobj_slot()->load_pyobj(); auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len"); bool is_buffer_allocated = false; diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 02558cbdf8968..671c28adef3e3 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -23,6 +23,8 @@ #include #include +using torch::utils::PyObjectPreservation; + template <> void THPPointer::free() { if (ptr) { @@ -32,238 +34,72 @@ void THPPointer::free() { PyTypeObject* THPStorageClass = nullptr; -PyObject* THPStorage_NewWithStorage( - PyTypeObject* type, - c10::Storage _storage, - bool allow_preexisting_pyobj) { - TORCH_CHECK( - PyType_IsSubtype(type, &THPStorageType), - "Creating a Storage subclass from a class that does not inherit from ", - "Storage is not possible. Make sure your class inherits from Storage."); - - auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (maybe_pyobj.has_value() && maybe_pyobj.value()) { - TORCH_CHECK( - allow_preexisting_pyobj, - "Creating a new Storage subclass ", - type->tp_name, - " but the raw Storage object is already associated to a python object ", - "of type ", - maybe_pyobj.value()->ob_type->tp_name); - PyObject* obj = *maybe_pyobj; - PyTypeObject* obj_type = Py_TYPE(obj); - TORCH_CHECK( - obj_type == type || PyType_IsSubtype(obj_type, type), - "Creating a new Storage subclass ", - type->tp_name, - " but the raw Storage object is already associated to a python object ", - "of type ", - maybe_pyobj.value()->ob_type->tp_name, - " which is not a subclass of the " - "requested type"); - return THPStorage_Wrap(std::move(_storage)); - } - +// Create a new Python Storage object, but don't set the pyobj slot on the +// c10::Storage object. 
+static PyObject* THPStorage_New(PyTypeObject* type, c10::Storage _storage) { PyObject* obj = type->tp_alloc(type, 0); TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); - auto s = reinterpret_cast(obj); + // Ensure that PyUnstable_TryIncref calls don't fail spuriously in + // free-threaded Python. + PyUnstable_EnableTryIncRef(obj); - new (&s->cdata) c10::MaybeOwned(); - - s->cdata = c10::MaybeOwned::owned(std::move(_storage)); + auto s = (THPStorage*)obj; + new (&s->cdata) c10::Storage(std::move(_storage)); + return obj; +} - if (!c10::impl::HermeticPyObjectTLS::get_state()) { - s->is_hermetic = false; - const auto& storage = THPStorage_Unpack(s); - storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); - } else { - s->is_hermetic = true; - } +// Create a new Python Storage object for a new c10::Storage, and set the +// pyobj slot. The c10::Storage must not already have a pyobj set. +PyObject* THPStorage_NewWithStorage(PyTypeObject* type, c10::Storage _storage) { + TORCH_CHECK( + type == THPStorageClass || PyType_IsSubtype(type, &THPStorageType), + "Creating a Storage subclass from a class that does not inherit from ", + "Storage is not possible. Make sure your class inherits from Storage."); + TORCH_INTERNAL_ASSERT(_storage.use_count() == 1); + c10::StorageImpl* storage_impl = _storage.unsafeGetStorageImpl(); + PyObject* obj = THPStorage_New(type, std::move(_storage)); + PyObjectPreservation::init_fresh_nonatomic( + storage_impl, storage_impl->pyobj_slot(), obj); return obj; } -// Wraps the c10::Storage with a storage PyObject +// Returns a PyObject wrapper for the c10::Storage object. The existing +// wrapper is returned if it already exists. PyObject* THPStorage_Wrap(c10::Storage storage) { - c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); + return THPStorage_New(THPStorageClass, std::move(storage)); } - c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); - std::optional maybe_pyobj = pyobj_slot->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (maybe_pyobj.has_value()) { - auto obj = *maybe_pyobj; - if (obj) { - TORCH_CHECK( - THPStorage_Check(obj), - "Expected a storage type, but got ", - Py_TYPE(obj)->tp_name); - - if (pyobj_slot->owns_pyobj()) { - pyobj_slot->set_owns_pyobj(false); - reinterpret_cast(obj)->cdata = - c10::MaybeOwned::owned(std::move(storage)); - return obj; - } else { - Py_INCREF(obj); - return obj; - } - } - } - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); -} - -static bool THPStorage_isPreservable(THPStorage* self) { - if (self->cdata.unsafeIsBorrowed()) { - return false; - } - auto const& storage = THPStorage_Unpack(self); - - if (self->is_hermetic) { - return false; - } + c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); + c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); - if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/true) != reinterpret_cast(self)) { - return false; - } - if (storage.use_count() <= 1) { - return false; + PyObject* obj = pyobj_slot->load_pyobj(); + if (obj) { + return Py_NewRef(obj); } - return true; -} -static bool THPStorage_tryPreserve(THPStorage* self) { - if (!THPStorage_isPreservable(self)) { - return false; + obj = THPStorage_New(THPStorageClass, std::move(storage)); + PyObject* wrapper = + PyObjectPreservation::init_once(storage_impl, pyobj_slot, obj); + if 
(wrapper != obj) { + // Another thread beat us to it + Py_DECREF(obj); + return Py_NewRef(wrapper); } - - const auto& storage = THPStorage_Unpack(self); - c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); - - auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/true); - // NOTE: It is possible to just set the PyObjectSlot here, but the point is - // that we should have already set PyObjectSlot when the storage PyObject - // was created. - TORCH_INTERNAL_ASSERT( - maybe_pyobj.has_value(), - "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject"); - - PyObject* pyobj = *maybe_pyobj; - - TORCH_CHECK( - THPStorage_Check(pyobj), - "Expected a storage type, but got ", - Py_TYPE(pyobj)->tp_name); - - TORCH_INTERNAL_ASSERT( - (void*)pyobj == (void*)self, - "Python storage and the PyObject in the internal PyObjectSlot are not at the same address"); - - TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj()); - - storage_impl->pyobj_slot()->set_owns_pyobj(true); - // When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to - // ensure the PyObject is in a valid state - _Py_NewReference(reinterpret_cast(self)); - - self->cdata = c10::MaybeOwned::borrowed(storage); - return true; + return obj; } -static void THPStorage_subclass_dealloc(PyObject* self) { +static void THPStorage_dealloc(PyObject* self) { THPStorage* _self = reinterpret_cast(self); - - if (THPStorage_tryPreserve(_self)) { - return; - } - - // Some subclass of StorageBase could be GC-tracked objects even - // though the base class is not - auto* type = Py_TYPE(self); - if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) { - PyObject_GC_UnTrack(self); - } - - bool has_finalizer = type->tp_finalize || type->tp_del; - - if (type->tp_finalize) { - PyObject_GC_Track(self); - if (PyObject_CallFinalizerFromDealloc(self) < 0) { - // The finalizer has resurrected the PyObject and there is a new Python - // reference to it, so we can just stop deallocating. Read about - // resurrection from `__del__` here: - // https://docs.python.org/3/reference/datamodel.html#object.__del__ - return; - } - PyObject_GC_UnTrack(self); - } - - // base test is unnecessary as THPStorae does not set this - if (type->tp_weaklistoffset) { - PyObject_ClearWeakRefs(self); + auto pyobj_slot = _self->cdata.unsafeGetStorageImpl()->pyobj_slot(); + if (pyobj_slot->load_pyobj() == self) { + TORCH_INTERNAL_ASSERT(_self->cdata.use_count() == 1); + pyobj_slot->clear(); } - - if (type->tp_del) { - PyObject_GC_Track(self); - type->tp_del(self); - if (Py_REFCNT(self) > 0) { - // Resurrected (see above comment about resurrection from `__del__`) - return; - } - PyObject_GC_UnTrack(self); - } - - if (has_finalizer) { - /* New weakrefs could be created during the finalizer call. - If this occurs, clear them out without calling their - finalizers since they might rely on part of the object - being finalized that has already been destroyed. 
*/ - if (type->tp_weaklistoffset) { - /* Modeled after GET_WEAKREFS_LISTPTR() */ - PyWeakReference** list = reinterpret_cast( - PyObject_GET_WEAKREFS_LISTPTR(self)); - while (*list) - _PyWeakref_ClearRef(*list); - } - } - - // Clear slots - { - PyTypeObject* base = type; - while (base != &THPStorageType) { - if (Py_SIZE(base)) { - clear_slots(base, self); - } - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - } - - // Clear __dict__ - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr(self); - if (dictptr != nullptr) { - PyObject* dict = *dictptr; - if (dict != nullptr) { - Py_DECREF(dict); - *dictptr = nullptr; - } - } - } - - TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type); - - _self->cdata.~MaybeOwned(); + _self->cdata.~Storage(); Py_TYPE(_self)->tp_free(self); - - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - Py_DECREF(type); } static PyObject* THPStorage_pynew( @@ -553,64 +389,13 @@ static PyMappingMethods THPStorage_mappingmethods = { reinterpret_cast(THPStorage_get), reinterpret_cast(THPStorage_set)}; -struct THPStorageMeta { - PyHeapTypeObject base; -}; - -static int THPStorageMetaType_init( - PyObject* cls, - PyObject* args, - PyObject* kwargs); - -static PyTypeObject THPStorageMetaType = { - PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0) - "torch._C._StorageMeta", /* tp_name */ - sizeof(THPStorageMeta), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - // NOLINTNEXTLINE(misc-redundant-expression) - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - DEFERRED_ADDRESS(&PyType_Type), /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - THPStorageMetaType_init, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ -}; - // TODO: implement equality PyTypeObject THPStorageType = { - PyVarObject_HEAD_INIT(&THPStorageMetaType, 0) + PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0) "torch._C.StorageBase", /* tp_name */ sizeof(THPStorage), /* tp_basicsize */ 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ + THPStorage_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ @@ -649,15 +434,6 @@ PyTypeObject THPStorageType = { THPStorage_pynew, /* tp_new */ }; -int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { - if (PyType_Type.tp_init(cls, args, kwargs) < 0) { - return -1; - } - (reinterpret_cast(cls))->tp_dealloc = - static_cast(THPStorage_subclass_dealloc); - return 0; -} - static PyObject* THPStorage_device(THPStorage* self, void* unused) { HANDLE_TH_ERRORS THPStorage_assertNotNull(self); @@ -692,13 +468,6 @@ bool THPStorage_init(PyObject* module) { THPUtils_addPyMethodDefs(methods, THPStorage_getMethods()); THPUtils_addPyMethodDefs(methods, 
THPStorage_getSharingMethods()); - THPStorageMetaType.tp_base = &PyType_Type; - if (PyType_Ready(&THPStorageMetaType) < 0) - return false; - Py_INCREF(&THPStorageMetaType); - PyModule_AddObject( - module, "_StorageMeta", reinterpret_cast(&THPStorageMetaType)); - THPStorageType.tp_methods = methods.data(); THPStorageType.tp_getset = THPStorage_properties; if (PyType_Ready(&THPStorageType) < 0) diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 698cd80548efa..89e853181f3da 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -11,15 +11,13 @@ struct THPStorage { PyObject_HEAD - c10::MaybeOwned cdata; - bool is_hermetic; + c10::Storage cdata; }; TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( PyTypeObject* type, - c10::Storage _storage, - bool allow_preexisting_pyobj = false); + c10::Storage _storage); TORCH_PYTHON_API extern PyTypeObject* THPStorageClass; inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) { @@ -49,7 +47,7 @@ TORCH_PYTHON_API void THPStorage_assertNotNull(PyObject* obj); TORCH_PYTHON_API extern PyTypeObject THPStorageType; inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) { - return *storage->cdata; + return storage->cdata; } inline const c10::Storage& THPStorage_Unpack(PyObject* obj) { diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 68c06f7c88c1c..178f735802fb7 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -529,9 +529,8 @@ static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) { THPUtils_typename(new_cdata)); c10::StorageImpl* ptr = static_cast(PyLong_AsVoidPtr(new_cdata)); - self->cdata.~MaybeOwned(); - self->cdata = c10::MaybeOwned::owned( - c10::Storage(c10::intrusive_ptr::reclaim_copy(ptr))); + self->cdata = + c10::Storage(c10::intrusive_ptr::reclaim_copy(ptr)); Py_INCREF(self); return reinterpret_cast(self); END_HANDLE_TH_ERRORS diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 97e689d36050c..8f55f22ae4ad4 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -180,7 +180,9 @@ struct TORCH_API AccumulateGrad : public Node { if (!GradMode::is_enabled() && !new_grad.is_sparse() && !new_grad.is_sparse_csr() && !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) && - at::caching::adjusted_use_count(new_grad) <= num_expected_refs && + impl::is_tensor_stealable( + new_grad, + num_expected_refs + at::caching::is_cached_tensor(new_grad)) && (new_grad.is_mkldnn() || utils::obeys_layout_contract(new_grad, variable))) { // See Case 1.1: Stealable dense new_grad @@ -193,7 +195,7 @@ struct TORCH_API AccumulateGrad : public Node { // SparseTensor should be the only one holding a reference to these. 
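// [Reviewer note, illustrative only; not part of the patch] The condition
// below treats a sparse new_grad as stealable when its _indices() and
// _values() are uniquely owned and the tensor itself has no more owners than
// expected. is_tensor_stealable (added to variable.h later in this patch)
// appears to additionally tolerate one extra reference when a live Python
// wrapper is the only other holder of the TensorImpl.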
new_grad._indices().use_count() <= 1 && new_grad._values().use_count() <= 1 && - new_grad.use_count() <= num_expected_refs) { + impl::is_tensor_stealable(new_grad, num_expected_refs)) { // Case 1.2: Stealable sparse new_grad // No scenario where we expect this to be true currently TORCH_INTERNAL_ASSERT_DEBUG_ONLY( diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 62770ef946592..a477bf4c3e507 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -86,8 +86,8 @@ bool can_accumulate_inplace(const Variable& v) { v.is_non_overlapping_and_dense() && // and we hold the last reference - at::caching::adjusted_use_count(v) == 1 && v.has_storage() && - v.storage().use_count() == 1); + impl::is_tensor_stealable(v, 1 + at::caching::is_cached_tensor(v)) && + v.has_storage() && v.storage().use_count() == 1); } } // anonymous namespace diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 74c7a751fe960..6d0bf5d0a8579 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -54,6 +54,7 @@ using namespace at; using namespace torch; using namespace torch::autograd; +using torch::utils::PyObjectPreservation; namespace { class OperatorArgsKwargsView { @@ -321,20 +322,15 @@ PyObject* THPVariableClass = nullptr; PyObject* ParameterClass = nullptr; -static PyObject* THPVariable_NewWithVar( - PyTypeObject* type, - const at::TensorBase& _var, - bool allow_preexisting_pyobj = false, - std::optional has_torch_dispatch_if_known = std::nullopt); - // clang-tidy gets confused by static const static constexpr const char* VOLATILE_WARNING = "volatile was removed and now has no effect. Use " "`with torch.no_grad():` instead."; +static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls); + static bool check_has_torch_dispatch(PyObject* obj) { - PyTypeObject* tp = Py_TYPE(obj); - if (THPVariable_CheckTypeExact(tp)) { + if (THPVariable_CheckExact(obj)) { return false; } py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__"); @@ -370,152 +366,86 @@ void activateGPUTrace() { c10::impl::GPUTrace::set_trace(getPyInterpreter()); } -PyObject* THPVariable_Wrap(const at::TensorBase& var) { +static void check_tensor_subclass(PyObject* obj, PyTypeObject* type) { + TORCH_CHECK( + PyObject_TypeCheck(obj, type), + "Creating a new Tensor subclass ", + type->tp_name, + " but the raw Tensor object is already associated to a python object ", + "of type ", + Py_TYPE(obj)->tp_name, + " which is not a subclass of the requested type"); +} + +// Generic for const Tensor& or Tensor&& +template +static PyObject* THPVariable_WrapWithType( + T&& var, + std::optional desired_type) { if (!var.defined()) { Py_RETURN_NONE; } - if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); - } - - std::optional mb_obj = - var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (mb_obj.has_value()) { - auto obj = *mb_obj; - if (obj) { - if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { - // C++ owns the Python object; this implies there weren't any other - // owning references to the Python object. 
Since we're making the - // object "live" again on Python side, let's flip back the ownership - // (Python owns C++) as it would now be unsound to deallocate the C++ - // object if all C++ references go to zero - var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false); - reinterpret_cast(obj)->cdata = - MaybeOwned::owned(Variable(var)); - // NB: incref is not necessary, because we are "stealing" the previous - // ownership from the Variable to return it here for the wrap - return obj; - } - Py_INCREF(obj); - return obj; + c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl(); + c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot(); + + PyObject* obj = pyobj_slot->load_pyobj(); + if (obj) { + if (desired_type) { + check_tensor_subclass(obj, *desired_type); } - // TODO: a better invariant is that if we tagged, we MUST have a valid - // PyObject. That's PyObject preservation - // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR - // being a thing, the PyObject field will get cleared when all references - // to the Python object are removed. + return Py_NewRef(obj); } - if (C10_LIKELY(var.device().type() != c10::kXLA)) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); + PyTypeObject* type = reinterpret_cast(THPVariableClass); + if (desired_type) { + type = *desired_type; + } else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) { + if (auto clazz = getPythonTensorClass(var.device())) { + type = reinterpret_cast(clazz); + } } - if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar((PyTypeObject*)clazz, var); - } + obj = type->tp_alloc(type, 0); + TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); -} + // Ensure that PyUnstable_TryIncref calls don't fail spuriously in + // free-threaded Python. + PyUnstable_EnableTryIncRef(obj); -static bool isResurrectable(THPVariable* self) { - // We want to divide this check into 2 cases. - - // 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is - // true). You might think that in this case, it is impossible for tp_clear to - // be called: surely the C++ reference to the PyObject is keeping it live? And - // you'd be right! In fact, when C++ owns the PyObject, we have an invariant - // that the refcount on the PyObject should be precisely one (because if you - // take out another reference to the PyObject, we're supposed to flip the - // ownership pointer back). In reality, you can violate this invariant - // temporarily with weak references, so we don't test for it in asserts. - - // 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is - // false). In this case, tp_clear can get called if the PyObject is referenced - // from a dead cycle, and nowhere else. But if resurrection did not occur, - // then the reference to C++ from the PyObject must be the ONLY reference to - // the C++ object. - if (self->cdata.unsafeIsBorrowed()) { - return false; - } - auto const& tensor = THPVariable_Unpack(self); - if (!tensor.defined() || tensor.use_count() <= 1) { - return false; - } - // Check if this is hermetic. If it is, no resurrection. 
- if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false) != (PyObject*)self) { - return false; - } - return true; -} + auto v = reinterpret_cast(obj); + new (&v->cdata) Tensor(std::forward(var)); -// returns true if successfully rezzed; if so, cancel the -// rest of deallocation -static bool THPVariable_tryResurrect(THPVariable* self) { - const auto& tensor = THPVariable_Unpack(self); - - if (!isResurrectable(self)) { - return false; + if (THPVariable_Unpack(obj).is_uniquely_owned()) { + // We can use a faster non-atomic code path if we have the only reference to + // a fresh Tensor. + PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj); + return obj; } - // At this point, we are definitely going to resurrect the tensor. So, the - // tensor better be defined :) - TORCH_INTERNAL_ASSERT(tensor.defined()); - - // There are other C++ owners of the tensor. Flip ownership - // so that C++ owns this Python object, and cancel deallocation. - TORCH_INTERNAL_ASSERT( - !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()); - - c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); - auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - - TORCH_INTERNAL_ASSERT( - maybe_pyobj.has_value(), - "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject"); - - tensor_impl->pyobj_slot()->set_owns_pyobj(true); - - // Resurrect the Python object. This is something CPython does - // internally occasionally, see - // https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259 - // so we just copy the pattern here. Note that we don't have to worry - // about saving and restoring the refcount (as the quoted code does) - // because we actually DO need to reset the refcount to one here, we - // can't assume that some other code has taken care of it. - // NB: this will overreport _Py_RefTotal but based on inspection of object.c - // there is no way to avoid this - - // When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to - // ensure the PyObject is in a valid state - _Py_NewReference((PyObject*)self); - - // Flip THPVariable to be non-owning - // (near use-after-free miss here: fresh MaybeOwned is created breaking - // reference on Tensor in struct BEFORE we overwrite the old one) - TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state()); - self->cdata = MaybeOwned::borrowed(tensor); - - // NB: At this point, tensor *could* be dead (e.g., some other C++ thread - // decrefed it.) At this point, it is probably waiting on the GIL to - // deallocate the Python object and will kill self, BUT NOT YET. 
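// [Illustrative sketch, not part of this patch] A minimal model of the
// "publish once" behaviour that PyObjectPreservation::init_once is assumed to
// provide for the slow path below: the first thread to store its freshly
// allocated wrapper wins, and every later caller gets the winning object back
// so it can discard its own allocation. The names in this sketch are
// hypothetical.
#include <atomic>

struct ExampleSlot {
  std::atomic<void*> wrapper{nullptr};

  // Returns whichever object ends up published in the slot.
  void* publish_once(void* candidate) {
    void* expected = nullptr;
    if (wrapper.compare_exchange_strong(expected, candidate)) {
      return candidate;  // this thread won the race
    }
    return expected;  // another thread published first; caller drops candidate
  }
};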
+ PyObject* wrapper = + PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj); + if (wrapper != obj) { + // Another thread beat us to it + Py_DECREF(obj); + if (desired_type) { + check_tensor_subclass(wrapper, *desired_type); + } + return Py_NewRef(wrapper); + } + return obj; +} - return true; +PyObject* THPVariable_Wrap(at::TensorBase&& var) { + return THPVariable_WrapWithType(std::move(var), std::nullopt); } -static int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) { - TORCH_INTERNAL_ASSERT( - false, "TensorBase tp_traverse function was not overridden properly"); - return 0; +PyObject* THPVariable_Wrap(const at::TensorBase& var) { + return THPVariable_WrapWithType(var, std::nullopt); } -static int THPFake_clear(THPVariable* self) { - TORCH_INTERNAL_ASSERT( - false, "TensorBase tp_clear function was not overridden properly"); - return 0; +PyObject* THPVariable_Wrap(const at::TensorBase& var, PyTypeObject* type) { + return THPVariable_WrapWithType(var, type); } static PyObject* THPVariable_pynew( @@ -677,16 +607,16 @@ static PyObject* THPVariable_as_subclass( ParsedArgs<1> parsed_args{}; auto r = parser.parse(_self, args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); // guard completely turns off torch dispatch modes, doesn't just pop off the // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; c10::impl::DisablePythonDispatcher dpd_g; - return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias()); + PyObject* obj = THPVariable_WrapWithType(self.alias(), (PyTypeObject*)cls); + if (check_has_torch_dispatch(obj)) { + THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true); + } + return obj; END_HANDLE_TH_ERRORS } @@ -701,11 +631,7 @@ static PyObject* THPVariable_make_subclass( ParsedArgs<7> parsed_args{}; auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); // guard completely turns off torch dispatch modes, doesn't just pop off the // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; @@ -738,7 +664,11 @@ static PyObject* THPVariable_make_subclass( data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); } - return THPVariable_NewWithVar((PyTypeObject*)cls, data); + PyObject* obj = THPVariable_WrapWithType(data, (PyTypeObject*)cls); + if (check_has_torch_dispatch(obj)) { + THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true); + } + return obj; END_HANDLE_TH_ERRORS } @@ -835,11 +765,7 @@ static PyObject* THPVariable_make_wrapper_subclass( auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); // This is an important safety check; without it, the default behavior will be // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch @@ -877,6 +803,8 @@ static PyObject* THPVariable_make_wrapper_subclass( /*storage_size=*/r.toSymIntOptional(14), r.toDispatchKeySetOptional(13)); + tensor.unsafeGetTensorImpl()->set_python_dispatch(true); + const auto sizes_strides_policy = r.stringViewOptional(10); if (sizes_strides_policy.has_value()) { tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides( @@ 
-892,13 +820,7 @@ static PyObject* THPVariable_make_wrapper_subclass( tensor.unsafeGetTensorImpl()->set_python_custom_layout(true); } - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - tensor, - // false is the default - /*allow_preexisting_pyobj=*/false, - // we checked __torch_dispatch__ above; avoid checking again. - /*has_torch_dispatch_if_known=*/true); + return THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls); END_HANDLE_TH_ERRORS } @@ -1699,11 +1621,7 @@ static PyObject* THPVariable_dtensor_new( auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); #ifndef NDEBUG // This is specifically for making a DTensor, which we know defines @@ -1756,14 +1674,9 @@ static PyObject* THPVariable_dtensor_new( /*storage_size=*/std::nullopt, extra_dispatch_keys); tensor.set_requires_grad(requires_grad); - py::object py_tensor = - py::reinterpret_steal(THPVariable_NewWithVar( - (PyTypeObject*)cls, - tensor, - // false is the default - /*allow_preexisting_pyobj=*/false, - // we know DTensor has __torch_dispatch__; avoid checking again. - /*has_torch_dispatch_if_known=*/true)); + tensor.unsafeGetTensorImpl()->set_python_dispatch(true); + py::object py_tensor = py::reinterpret_steal( + THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls)); py_tensor.attr(dtensor_interned_strings._spec) = spec; py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor; return py_tensor.release().ptr(); @@ -3440,15 +3353,16 @@ static PyTypeObject THPVariableMetaType = { nullptr, /* tp_new */ }; +static void THPVariable_dealloc(PyObject* self); +static int THPVariable_clear(THPVariable* self); +static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg); + static PyTypeObject THPVariableType = { PyVarObject_HEAD_INIT(&THPVariableMetaType, 0) "torch._C.TensorBase", /* tp_name */ sizeof(THPVariable), /* tp_basicsize */ 0, /* tp_itemsize */ - // This is unspecified, because it is illegal to create a THPVariableType - // directly. Subclasses will have their tp_dealloc set appropriately - // by the metaclass - nullptr, /* tp_dealloc */ + THPVariable_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ @@ -3467,9 +3381,8 @@ static PyTypeObject THPVariableType = { Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ nullptr, /* tp_doc */ - // Also set by metaclass - (traverseproc)THPFake_traverse, /* tp_traverse */ - (inquiry)THPFake_clear, /* tp_clear */ + (traverseproc)THPVariable_traverse, /* tp_traverse */ + (inquiry)THPVariable_clear, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ @@ -3498,345 +3411,68 @@ PyObject* THPVariable_pynew( type != &THPVariableType, "Cannot directly construct TensorBase; subclass it and then construct that"); jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR); - auto tensor = torch::utils::base_tensor_ctor(args, kwargs); // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was // given a raw pointer that will refcount bump // NB: base_tensor_ctor can call into dispatched ATen functions (e.g., // alias(), lift_fresh()) which can return Tensor subclasses. We allow // these to be passed on directly. 
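// [Reviewer note, illustrative only; not part of the patch] Passing `type`
// through to THPVariable_WrapWithType below replaces the old
// allow_preexisting_pyobj flag: if base_tensor_ctor returned a tensor that
// already has a Python wrapper, WrapWithType checks that the wrapper is an
// instance of the requested type and returns it instead of allocating a new
// object.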
- return THPVariable_NewWithVar( - type, - tensor, - /*allow_preexisting_pyobj=*/true); + PyObject* obj = THPVariable_WrapWithType( + torch::utils::base_tensor_ctor(args, kwargs), type); + if (check_has_torch_dispatch(obj)) { + THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true); + } + return obj; END_HANDLE_TH_ERRORS } -static int THPVariable_subclass_clear(THPVariable* self) { - // Is it OK for an object to still be live after running - // tp_clear? Yes. When Python is breaking reference cycles, it can't assume - // that an object will dealloc after it's cleared. The source code explicitly - // handles this case: - // https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025 - - // Note that we don't need to actually resurrect here. There are 2 cases: - // 1. The PyObject is not part of a reference cycle. In this case, we don't - // need to do anything. The GC will move on to try and break the reference - // cycle on another object, which will eventually trigger tp_dealloc (and thus - // resurrection). - - // 2. The PyObject is part of a reference cycle. This case should not actually - // be possible, due to the logic in our tp_traverse - // (THPVariable_subclass_traverse). - - // In fact, resurrecting here breaks the invariant that "C++ owns Python only - // when PyObject's refcount would otherwise be 0". Most immediately, as we're - // merely breaking reference cycles here, there can be other references to the - // PyObject. *However*, if other objects in the refcycle resurrect, then we - // will be in a state where the PyObject has multiple Python references, yet - // C++ owns the PyObject. - - // See https://github.com/pytorch/pytorch/pull/75933 for more discussion. - if (isResurrectable(self)) { - return 0; - } - +static int THPVariable_clear(THPVariable* self) { // First clear Tensor specific things - Py_CLEAR(self->backward_hooks); Py_CLEAR(self->post_accumulate_grad_hooks); - const auto& tensor = THPVariable_Unpack(self); - if (tensor.defined()) { - // Two situations to consider: - // PyObject -owns-> Tensor - // unsafeIsBorrowed() is FALSE. We're obligated to look through - // Tensor to break references. Clearing cdata must induce the - // destruction of the C++ Tensor. If there were other references - // to C++ tensor, the Python object would have been resurrected - // by flipping the ownership. - // Tensor -owns-> PyObject - // unsafeIsBorrowed() is TRUE. We're deallocating the PyObject - // because Tensor asked us to (it's already destructing). - - if (!self->cdata.unsafeIsBorrowed() && - tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false) == (PyObject*)self) { - // TODO: empirically, on OS X this assert appears to be untrue - // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn - // distributed/rpc/test_process_group_agent.py - // - // libc++abi.dylib: terminating with uncaught exception of type - // c10::Error: - // !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL - // ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171, - // please report a bug to PyTorch. 
Exception raised from - // THPVariable_subclass_clear at - // ../torch/csrc/autograd/python_variable.cpp:171 (most recent call - // first): frame #0: c10::Error::Error(c10::SourceLocation, - // std::__1::basic_string, - // std::__1::allocator >) + 98 (0x1158a0442 in libc10.dylib) frame - // #1: c10::detail::torchCheckFail(char const*, char const*, unsigned - // int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2: - // c10::detail::torchInternalAssertFail(char const*, char const*, - // unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9 - // (0x1141e3f89 in libtorch_python.dylib) frame #3: - // THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in - // libtorch_python.dylib) frame #4: - // THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in - // libtorch_python.dylib) frame #5: (anonymous - // namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*, - // _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6: - // c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in - // libc10.dylib) frame #7: - // c10::MaybeOwned::operator=(c10::MaybeOwned&&) - // + 91 (0x11488c11b in libtorch_python.dylib) frame #8: - // THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in - // libtorch_python.dylib) frame #47: start + 1 - // (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???) - // TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()); - if (auto grad_acc = - torch::autograd::impl::try_get_grad_accumulator(tensor)) { - grad_acc->pre_hooks().clear(); - grad_acc->tensor_pre_hooks().clear(); - grad_acc->retains_grad_hooks().clear(); - } + if (self->cdata.defined()) { + auto pyobj_slot = self->cdata.unsafeGetTensorImpl()->pyobj_slot(); + // Typically the Tensor's pyobj_slot points back to this object. The only + // time that's not the case is if we had a race in THPVariable_Wrap and we + // need to discard the Python object because some other thread beat us to + // setting the pyobj_slot. + if (pyobj_slot->load_pyobj() == (PyObject*)self) { + // A Tensor's Python object should only be destroyed when the Tensor has + // no other references too. + TORCH_INTERNAL_ASSERT(self->cdata.use_count() == 1); + + // Clear the pyobj_slot so that a try_incref() call from + // weak_intrusive_ptr::lock() won't see a freed pointer. + pyobj_slot->clear(); } } - TORCH_INTERNAL_ASSERT(!isResurrectable(self)); { // MapAllocator can take significant time to release large tensors; // release the GIL here to avoid impacting main thread perf. pybind11::gil_scoped_release no_gil; - self->cdata = MaybeOwned(); + self->cdata = Variable(); } - // Since we override the basic subtype_clear from CPython, we need a crappy - // version here just like for traverse and dealloc - - // Clear all slots until we get to the base Tensor class - PyTypeObject* type = Py_TYPE((PyObject*)self); - PyTypeObject* base = type; - while (base != &THPVariableType) { - if (Py_SIZE(base)) - clear_slots(base, (PyObject*)self); - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - - // Assume we never have managed dict for Tensors as we don't set the flag on - // the base class - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr((PyObject*)self); - if (dictptr && *dictptr) - Py_CLEAR(*dictptr); - } - return 0; } -// NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc -// on subclasses. 
It's never valid to construct a THPVariable so it's not -// necessary to implement the dealloc for that case -static void THPVariable_subclass_dealloc(PyObject* self) { - if (THPVariable_tryResurrect((THPVariable*)self)) - return; - - // This is like a crappy version of subtype_dealloc. - // Unfortunately, we cannot directly delegate to - // subtype_dealloc as it will start walking the parent - // chain *starting with* the type of self, which will cause - // us to go back to our custom dealloc. - // - // We have to replicate the subtype_dealloc logic to ensure - // that finalizers are handled correctly - PyTypeObject* type = Py_TYPE(self); - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented"); - +static void THPVariable_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); - // TODO: consider using trash can - - bool has_finalizer = type->tp_finalize || type->tp_del; - - if (type->tp_finalize) { - PyObject_GC_Track(self); - if (PyObject_CallFinalizerFromDealloc(self) < 0) { - /* Resurrected */ - return; - } - PyObject_GC_UnTrack(self); - } - - // base test is unnecessary as THPVariable does not set this - if (type->tp_weaklistoffset) { - PyObject_ClearWeakRefs(self); - } - - if (type->tp_del) { - PyObject_GC_Track(self); - type->tp_del(self); - if (Py_REFCNT(self) > 0) { - /* Resurrected */ - return; - } - PyObject_GC_UnTrack(self); - } - - if (has_finalizer) { - /* New weakrefs could be created during the finalizer call. - If this occurs, clear them out without calling their - finalizers since they might rely on part of the object - being finalized that has already been destroyed. */ - if (type->tp_weaklistoffset) { - /* Modeled after GET_WEAKREFS_LISTPTR() */ - PyWeakReference** list = - (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self); - while (*list) - _PyWeakref_ClearRef(*list); - } - } - - // Clear all slots until we get to base class THPVariableType - { - PyTypeObject* base = type; - while (base != &THPVariableType) { - if (Py_SIZE(base)) { - clear_slots(base, self); - } - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - } - - // All Python defined classes have __dict__ - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr(self); - if (dictptr != nullptr) { - PyObject* dict = *dictptr; - if (dict != nullptr) { - Py_DECREF(dict); - *dictptr = nullptr; - } - } - } - - // subtype_dealloc allows for this but we don't - TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type); - - // Finally clear out the base THPVariable - THPVariable_subclass_clear((THPVariable*)self); - ((THPVariable*)self)->cdata.~MaybeOwned(); + THPVariable_clear((THPVariable*)self); + ((THPVariable*)self)->cdata.~Variable(); Py_TYPE(self)->tp_free(self); - - // Python defined subclasses should always be on the heap - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - Py_DECREF(type); } -// Creates a new Python object for a Variable. -static PyObject* THPVariable_NewWithVar( - PyTypeObject* type, - const at::TensorBase& _var, - bool allow_preexisting_pyobj, - std::optional has_torch_dispatch_if_known) { - // Make sure that the reinterpret into a THPVariable* will be valid - TORCH_CHECK( - type == &THPVariableType || PyType_IsSubtype(type, &THPVariableType), - "Creating a Tensor subclass from a class ", - "that does not inherit from Tensor is not possible. 
Make sure your class inherits from Tensor."); - - // This function overwrite the Tensor's pyobj field without extra checks - // Make sure it is not set otherwise we would leak memory - auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - - // Under some circumstances, we may attempt to create a new Python - // object for a variable that already has a Python object. The most common - // situation this can occur is if you have a TorchDispatchMode active that - // is returning a subclass from lift_fresh (which is invoked to - // appropriately "wrap" a constant tensor into whatever ambient modes are - // active.) - // - // In general, it is impossible to handle this case compositionally. - // Suppose you have a user call ATensor([1, 2, 3]) when a mode is active - // that is transforming all ops (including the internal lift_fresh call that - // transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output - // BTensor, where ATensor and BTensor are completely unrelated subclasses - // and there is no way to compose them. There is no way to satisfy the user - // request here: in particular, you can't just try to re-invoke the ATensor - // constructor on the returned BTensor, because (1) this could cause an - // infinite loop--we are already in ATensor.__new__ and (2) there isn't any - // guarantee that ATensor.__new__ supports a single element constructor - // anyway. - // - // However, a more common case is a user just called torch.Tensor([1, 2, 3]), - // and a fake tensor mode is active. Really, all you want is to get back - // a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3) - // would have returned a fake tensor (concretely, the way this happens - // is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it - // turns into a FakeTensor when we call lift_fresh on this real tensor). - // This case is compositional because FakeTensor is a subclass of Tensor, so - // it's valid for us to return it in place of a Tensor. So this is what we - // do. - - if (mb_obj.has_value() && mb_obj.value()) { - TORCH_CHECK( - allow_preexisting_pyobj, - "Creating a new Tensor subclass ", - type->tp_name, - " but the raw Tensor object is already associated to a python object ", - "of type ", - mb_obj.value()->ob_type->tp_name); - // Even if we allow pre-existing PyObject, we don't allow completely - // ignoring the requested type. Check that we fulfilled a subtype - // relation here. In the common case the requested type is Tensor and - // this always succeeds. - PyObject* obj = *mb_obj; - // Check if it's OK to just directly return the Python object without - // allocating a new variable. We just check that the existing Python - // object is a subclass of the requested type. 
- PyTypeObject* obj_type = Py_TYPE(obj); - TORCH_CHECK( - obj_type == type || PyType_IsSubtype(obj_type, type), - "Creating a new Tensor subclass ", - type->tp_name, - " but the raw Tensor object is already associated to a python object ", - "of type ", - mb_obj.value()->ob_type->tp_name, - " which is not a subclass of the " - "requested type"); - // We may (in fact, we typically will) need to resurrect this - return THPVariable_Wrap(_var); - } - - PyObject* obj = type->tp_alloc(type, 0); - if (obj) { - auto v = (THPVariable*)obj; - // TODO: named constructor to avoid default initialization - new (&v->cdata) MaybeOwned(); - if (c10::impl::HermeticPyObjectTLS::get_state()) { - // Do NOT initialize pyobj field on the tensor, you own the C++ - v->cdata = MaybeOwned::owned(Variable(_var)); - TORCH_INTERNAL_ASSERT( - !check_has_torch_dispatch(obj), - "While HermeticPyObject was enabled, we attempted to create a tensor " - "subclass with __torch_dispatch__. This violates the invariant that " - "operations in HermeticPyObject have equivalent C++ implementations. " - "If your operator registered from Python operator registration isn't " - "doing anything strange, there may be an internal PyTorch bug involving " - "not appropriately disabling TorchDispatchMode before executing " - "Python op registration."); - } else { - // Normal codepath - v->cdata = MaybeOwned::owned(Variable(_var)); - const auto& var = THPVariable_Unpack(v); - var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj); - if (has_torch_dispatch_if_known.has_value() - ? *has_torch_dispatch_if_known - : check_has_torch_dispatch(obj)) { - var.unsafeGetTensorImpl()->set_python_dispatch(true); - } - } - } - return obj; +static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) { + TORCH_CHECK_TYPE( + PyType_Check(cls), + "cls must be a type (got ", + Py_TYPE(cls)->tp_name, + ")"); + PyTypeObject* type = reinterpret_cast(cls); + TORCH_CHECK_TYPE( + type == &THPVariableType || cls == THPVariableClass || + PyType_IsSubtype(type, &THPVariableType), + "Creating a Tensor subclass from a class that does not inherit from " + "Tensor is not possible. Make sure your class inherits from Tensor."); } /// NOTE [ PyObject Traversal ] @@ -3855,7 +3491,7 @@ static PyObject* THPVariable_NewWithVar( /// into account these C++ ownership links. /// /// The main danger here comes from the fact that, while all python-related code -/// is thread safe wrt the GC execution (thanks to the GIL), other threads might +/// is thread safe wrt the GC execution, other threads might /// be using our C++ objects arbitrarily which can lead to shared_ptr ref count /// going up or down in between the different traverse/clear invocations. 
The /// one constraint we add here that is not explicitly mentioned in the GC @@ -3885,124 +3521,46 @@ static PyObject* THPVariable_NewWithVar( /// https://github.com/pytorch/pytorch/issues/7343 /// -static int traverse_slots( - PyTypeObject* type, - PyObject* self, - visitproc visit, - void* arg) { - auto n = Py_SIZE(type); - auto mp = type->tp_members; - for (Py_ssize_t i = 0; i < n; i++, mp++) { - if (mp->type == T_OBJECT_EX) { - char* addr = (char*)self + mp->offset; - PyObject* obj = *(PyObject**)addr; - if (obj != nullptr) { - int err = visit(obj, arg); - if (err) - return err; - } - } - } - return 0; -} - -static int THPVariable_subclass_traverse( - PyObject* self, - visitproc visit, - void* arg) { - // If the tensor is eligible to be resurrected, don't traverse it; instead - // treat all of its references as a root (as they WOULD be a root since we - // can treat the inbound C++ references as root owners). - // - // This works because unlike conventional GCs, Python's GC operates in two - // phases: first it uses traverse to discover roots, and then it uses traverse - // to do reachability. Bypassing traverse during root discovery forces Python - // to treat self as a root for everything it refers to. For a full - // explanation of the algorithm see - // https://devguide.python.org/garbage_collector/ - // - // NB: if we don't hold an owning reference to the underlying Tensor, it is - // possible that the underlying Tensor has already gone dead. In that case, - // it's not safe to access it. But it's also safe to traverse, because if - // the underlying Tensor *is* live, then root discovery will determine that - // self is live, and nothing will get GC'ed anyway (resurrection cannot happen - // if the C++ objects owns the PyObject) +static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg) { THPVariable* var = reinterpret_cast(self); - if (isResurrectable(var)) { - return 0; - } - - // Crappy version of subtype_traverse; same deal as - // THPVariable_subclass_dealloc - - PyTypeObject* type = Py_TYPE(self); - // Traverse slots until we get to base class THPVariableType - { - PyTypeObject* base = type; - while (base != &THPVariableType) { - if (Py_SIZE(base)) { - int err = traverse_slots(base, self, visit, arg); - if (err) - return err; - } - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - } - - // All Python defined classes have __dict__ - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr(self); - if (dictptr && *dictptr) - Py_VISIT(*dictptr); - } - - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - Py_VISIT(type); - - // Finally traverse THPVariable special stuff Py_VISIT(var->backward_hooks); Py_VISIT(var->post_accumulate_grad_hooks); - if (!var->cdata.unsafeIsBorrowed()) { - const auto& tensor = THPVariable_Unpack(var); - if (tensor.defined()) { - // WARNING: The grad_fn traversal logic is very subtle, if you change - // this, be very careful not to re-introduce this bug: - // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c - - // We ensure that we follow NOTE [ PyObject Traversal ] he by checking - // that this python object is the sole owner of the underlying Tensor and - // that this Tensor is the sole owner of its grad_fn. In this case, the - // only way to get a new reference to the grad_fn is by using this python - // object, which requires the GIL to be accessed. 
Note that this is only - // valid as long as user don't share non-owning references across - // different threads (which is crazy and should never be done). - auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor); - if (tensor.use_count() == 1) { - if (autograd_meta) { - // Do NOT call grad_fn() here as that might trigger a recompute - const auto& grad_fn = autograd_meta->grad_fn_; - if (grad_fn && grad_fn.use_count() == 1) { - // All Node can have a pyobj (stored in "pyobj_") - Py_VISIT(grad_fn->pyobj()); - // PyNode are special as they also have an "obj" field - if (auto py_node_fn = dynamic_cast(grad_fn.get())) { - Py_VISIT(py_node_fn->obj); - } + const auto& tensor = THPVariable_Unpack(var); + if (tensor.defined()) { + // WARNING: The grad_fn traversal logic is very subtle, if you change + // this, be very careful not to re-introduce this bug: + // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c + + // We ensure that we follow NOTE [ PyObject Traversal ] he by checking + // that this python object is the sole owner of the underlying Tensor and + // that this Tensor is the sole owner of its grad_fn. In this case, the + // only way to get a new reference to the grad_fn is by using this python + // object, which requires the GIL to be accessed. Note that this is only + // valid as long as user don't share non-owning references across + // different threads (which is crazy and should never be done). + auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor); + if (tensor.use_count() == 1) { + if (autograd_meta) { + // Do NOT call grad_fn() here as that might trigger a recompute + const auto& grad_fn = autograd_meta->grad_fn_; + if (grad_fn && grad_fn.use_count() == 1) { + // All Node can have a pyobj (stored in "pyobj_") + Py_VISIT(grad_fn->pyobj()); + // PyNode are special as they also have an "obj" field + if (auto py_node_fn = dynamic_cast(grad_fn.get())) { + Py_VISIT(py_node_fn->obj); } } } - if (autograd_meta) { - for (const auto& hook : torch::autograd::impl::hooks(tensor)) { - if (auto pyhook = - dynamic_cast(hook.get())) { - Py_VISIT(pyhook->dict); - } + } + if (autograd_meta) { + for (const auto& hook : torch::autograd::impl::hooks(tensor)) { + if (auto pyhook = dynamic_cast(hook.get())) { + Py_VISIT(pyhook->dict); } } } } - return 0; } @@ -4010,17 +3568,6 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { if (PyType_Type.tp_init(cls, args, kwargs) < 0) { return -1; } - // It is important for all three of these to be overridden correctly for the - // resurrection checks to properly happen. In particular, an older version - // was not overriding tp_clear here. This lead to the default subtype_clear - // running on the Tensor object (as only TensorBase tp_clear was custom), - // clearing the __dict__ field, before the TensorBase custom clear was called - // and would properly detect the resurrect. 
- // See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior - ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc; - ((PyTypeObject*)cls)->tp_traverse = - (traverseproc)THPVariable_subclass_traverse; - ((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear; // Don't do anything for the base Tensor class if (!THPVariableClass) { diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 1b2116ec1ee6b..5b6f089990693 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -17,7 +17,7 @@ namespace py = pybind11; struct THPVariable { PyObject_HEAD // Payload - c10::MaybeOwned cdata; + at::Tensor cdata; // Hooks to be run on backwards pass (corresponds to Python attr // '_backwards_hooks', set by 'register_hook') PyObject* backward_hooks = nullptr; @@ -37,7 +37,11 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass; TORCH_PYTHON_API extern PyObject* ParameterClass; bool THPVariable_initModule(PyObject* module); +TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase&& var); TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var); +TORCH_PYTHON_API PyObject* THPVariable_Wrap( + const at::TensorBase& var, + PyTypeObject* type); inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Check that a python object is a `Tensor`, but not a `Tensor` subclass. @@ -69,7 +73,7 @@ inline bool THPVariable_Check(PyObject* obj) { } inline const at::Tensor& THPVariable_Unpack(THPVariable* var) { - return *var->cdata; + return var->cdata; } inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index ed97dc4530eb4..00bdb91c36867 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -65,7 +65,9 @@ inline at::Tensor clone_obey_contract( .new_empty_strided_symint( variable.sym_sizes(), variable.sym_strides(), - variable.options().memory_format(std::nullopt)) + variable.options() + .memory_format(std::nullopt) + .dtype(new_grad.dtype())) .copy_(new_grad)); } else { // (2) diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h index 6e0494df5cf47..616b0fa0331bc 100644 --- a/torch/csrc/autograd/utils/wrap_outputs.h +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -70,6 +70,10 @@ inline PyObject* wrap(const at::Tensor& tensor) { return THPVariable_Wrap(tensor); } +inline PyObject* wrap(at::Tensor&& tensor) { + return THPVariable_Wrap(std::move(tensor)); +} + inline PyObject* wrap(const at::Scalar& scalar) { return wrap(scalar_to_tensor(scalar)); } diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index a297a9f5ef425..05dbfdaa44325 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -197,6 +197,22 @@ TORCH_API std::unique_ptr& post_acc_grad_hooks( TORCH_API void create_cpp_hook( const at::TensorBase& /*self*/, bool is_retains_grad_hooks = false); + +inline bool is_tensor_stealable( + const at::Tensor& new_grad, + size_t num_expected_refs = 1) { + size_t use_count = new_grad.use_count(); + if (use_count <= num_expected_refs) { + return true; + } + if (use_count >= 2 && + new_grad.unsafeGetTensorImpl()->pyobj_slot()->has_unique_reference()) { + // The Python wrapper, if it exists, also has a reference to the Tensor. 
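// [Reviewer note, illustrative only; not part of the patch] Worked example:
// with num_expected_refs == 1, a tensor whose only extra reference comes from
// a live Python wrapper has use_count == 2; bumping the expected count to 2
// below still classifies it as stealable, while any additional C++ owner
// keeps it non-stealable.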
+ num_expected_refs++; + } + return use_count <= num_expected_refs; +} + } // namespace impl //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -894,7 +910,7 @@ inline Variable make_variable( bool requires_grad = false, bool allow_tensor_metadata_change = true) { if (data.defined()) { - if (data.getIntrusivePtr().use_count() == 1 && + if (impl::is_tensor_stealable(data) && data.getIntrusivePtr()->unique_version()) { auto data_impl = data.unsafeReleaseIntrusivePtr(); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 188f92557761d..05d7aa04425f5 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -1,6 +1,5 @@ #if !defined(C10_MOBILE) && !defined(ANDROID) -#include #include #include #include @@ -31,6 +30,8 @@ namespace fs = std::filesystem; #include #include #include +#define access _access +#define F_OK 0 #else #include #include @@ -78,6 +79,15 @@ std::string normalize_path_separator(const std::string& orig_path) { return normalized_path; } +bool file_exists(const std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc{}; + return lstat(path.c_str(), &rc) == 0; +#endif +} + std::string create_temp_dir() { #ifdef _WIN32 try { @@ -145,8 +155,7 @@ namespace torch::inductor { namespace { const nlohmann::json& load_json_file(const std::string& json_path) { - TORCH_CHECK( - c10::filesystem::exists(json_path), "File not found: ", json_path); + TORCH_CHECK(file_exists(json_path), "File not found: ", json_path); std::ifstream json_file(json_path); TORCH_CHECK(json_file.is_open()); @@ -283,6 +292,102 @@ std::tuple get_cpp_compile_command( return std::make_tuple(cmd, target_file); } +bool recursive_mkdir(const std::string& dir) { + // Creates directories recursively, copied from jit_utils.cpp + // Check if current dir exists + const char* p_dir = dir.c_str(); + const bool dir_exists = (access(p_dir, F_OK) == 0); + if (dir_exists) { + return true; + } + + // Try to create current directory +#ifdef _WIN32 + int ret = _mkdir(dir.c_str()); +#else + int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + // Success + if (ret == 0) { + return true; + } + + // Find folder separator and check if we are at the top + auto pos = dir.find_last_of(k_separator); + if (pos == std::string::npos) { + return false; + } + + // Try to create parent directory + if (!(recursive_mkdir(dir.substr(0, pos)))) { + return false; + } + + // Try to create complete path again +#ifdef _WIN32 + ret = _mkdir(dir.c_str()); +#else + ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + return ret == 0; +} + +bool recursive_rmdir(const std::string& path) { +#ifdef _WIN32 + std::error_code ec; + return fs::remove_all(path, ec) != static_cast(-1); +#else + DIR* dir = opendir(path.c_str()); + if (!dir) { + return false; + } + + struct dirent* entry = nullptr; + struct stat statbuf{}; + bool success = true; + + // Iterate through directory entries + while ((entry = readdir(dir)) != nullptr) { + std::string name = entry->d_name; + + // Skip "." and ".." + if (name == "." 
|| name == "..") { + continue; + } + + std::string full_path = path; + full_path.append("/").append(name); + + // Get file status + if (stat(full_path.c_str(), &statbuf) != 0) { + success = false; + continue; + } + + if (S_ISDIR(statbuf.st_mode)) { + // Recursively delete subdirectory + if (!recursive_rmdir(full_path)) { + success = false; + } + } else { + // Delete file + if (unlink(full_path.c_str()) != 0) { + success = false; + } + } + } + + closedir(dir); + + // Remove the directory itself + if (rmdir(path.c_str()) != 0) { + success = false; + } + + return success; +#endif +} + std::string compile_so( const std::string& cpp_filename, std::vector& obj_filenames) { @@ -312,7 +417,7 @@ std::string compile_so( // Move the mmapped weights onto the .so std::string serialized_weights_path = filename + "_serialized_weights.bin"; - if (c10::filesystem::exists(serialized_weights_path)) { + if (file_exists(serialized_weights_path)) { std::ifstream serialized_weights_file( serialized_weights_path, std::ios::binary); TORCH_CHECK( @@ -534,13 +639,11 @@ std::unordered_map AOTIModelPackageLoader:: parent_path_idx != std::string::npos, "Failed to find parent path in " + output_path_str); std::string parent_path = output_path_str.substr(0, parent_path_idx); - std::error_code ec{}; - c10::filesystem::create_directories(parent_path, ec); TORCH_CHECK( - ec.value() == 0, + recursive_mkdir(parent_path), "Failed to create directory " + parent_path, ": ", - ec.message()); + c10::utils::str_error(errno)); LOG(INFO) << "Extract file: " << metadata_filename << " to " << output_path_str; @@ -554,7 +657,7 @@ std::unordered_map AOTIModelPackageLoader:: metadata[item.key()] = item.value().get(); } // Clean up temporary directory - c10::filesystem::remove_all(temp_dir, ec); + recursive_rmdir(temp_dir); return metadata; } @@ -646,13 +749,11 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( "Failed to find parent path in " + output_file_path); std::string parent_path = output_file_path.substr(0, parent_path_idx); - std::error_code ec{}; - c10::filesystem::create_directories(parent_path, ec); TORCH_CHECK( - ec.value() == 0, + recursive_mkdir(parent_path), "Failed to create directory " + parent_path, ": ", - ec.message()); + c10::utils::str_error(errno)); // Extracts file to the temp directory zip_archive.extract_file(zip_filename_str, output_path_str); @@ -731,8 +832,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( AOTIModelPackageLoader::~AOTIModelPackageLoader() { // Clean up the temporary directory if (!temp_dir_.empty()) { - std::error_code ec; - c10::filesystem::remove_all(temp_dir_, ec); + recursive_rmdir(temp_dir_); } } diff --git a/torch/csrc/inductor/cpp_prefix.h b/torch/csrc/inductor/cpp_prefix.h index decdef52a1daa..7dc161d13fd52 100644 --- a/torch/csrc/inductor/cpp_prefix.h +++ b/torch/csrc/inductor/cpp_prefix.h @@ -74,6 +74,22 @@ template struct IsVecMaskType> : std::true_type {}; #endif +template +struct GetScalarType { + using type = T; +}; + +#if INDUCTOR_USE_VECTOR_TYPES() +template +struct GetScalarType> { + using type = T; +}; +template +struct GetScalarType> { + using type = T; +}; +#endif + template struct CascadeSumHelper { // A data struct to help cascade summation: @@ -139,7 +155,7 @@ struct WelfordHelper { // 1. Save the reciprocal of weights to avoid redundant divisions. // 2. Save the welford stack, which is used to combine welford reduction // with cascade summation to improve numerical stability. 
- static std::vector weight_recps; + static std::vector::type> weight_recps; std::vector> welford_stk{}; uint64_t depth{0}; // depth of welford_stk. uint64_t num_chunks{0}; // number of chunks stored in welford_stk. @@ -154,9 +170,9 @@ struct WelfordHelper { }; template -std::vector WelfordHelper::weight_recps = - []() { - using scalar_t = typename T::value_type; +std::vector::type> + WelfordHelper::weight_recps = []() { + using scalar_t = typename GetScalarType::type; std::vector temp(kChunkSize); for (const auto i : c10::irange(kChunkSize)) { temp[i] = scalar_t(static_cast(1) / static_cast(i + 1)); @@ -202,21 +218,19 @@ Welford welford_combine( // stability. // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance // https://en.wikipedia.org/wiki/Pairwise_summation - if constexpr (IsVecType::value) { - if (w != nullptr && w->depth > 0 && acc.index == kChunkSize) { - w->welford_stk[0] = welford_combine(w->welford_stk[0], acc); - w->num_chunks += 1; - acc.mean = T(0); - acc.m2 = T(0); - acc.weight = T(0); - acc.index = 0; - uint64_t mask = w->num_chunks; - for (uint64_t j = 1; j < w->depth && (mask & 1) == 0; ++j) { - w->welford_stk[j] = - welford_combine(w->welford_stk[j], w->welford_stk[j - 1]); - w->welford_stk[j - 1] = Welford(); - mask >>= 1; - } + if (w != nullptr && w->depth > 0 && acc.index == kChunkSize) { + w->welford_stk[0] = welford_combine(w->welford_stk[0], acc); + w->num_chunks += 1; + acc.mean = T(0); + acc.m2 = T(0); + acc.weight = T(0); + acc.index = 0; + uint64_t mask = w->num_chunks; + for (uint64_t j = 1; j < w->depth && (mask & 1) == 0; ++j) { + w->welford_stk[j] = + welford_combine(w->welford_stk[j], w->welford_stk[j - 1]); + w->welford_stk[j - 1] = Welford(); + mask >>= 1; } } // Add a single data point @@ -224,22 +238,18 @@ Welford welford_combine( auto new_weight = acc.weight + T(1); auto delta = data - acc.mean; T new_mean; - if constexpr (!IsVecType::value) { - new_mean = acc.mean + delta / new_weight; - } else { - // use new_index to fecth 1 / new_weight to avoid divisions - new_mean = acc.mean + - ((w == nullptr || acc.index >= w->weight_recps.size()) - ? delta / new_weight - : delta * T(w->weight_recps[acc.index])); - } + // use new_index to fecth 1 / new_weight to avoid divisions + new_mean = acc.mean + + ((w == nullptr || acc.index >= w->weight_recps.size()) + ? delta / new_weight + : delta * T(w->weight_recps[acc.index])); auto new_delta = data - new_mean; auto result = Welford{new_mean, acc.m2 + delta * new_delta, new_weight, new_index}; return result; } -template +template Welford welford_combine(Welford& acc, WelfordHelper* w) { for (const auto i : c10::irange(w->depth)) { acc = welford_combine(acc, w->welford_stk[i]); @@ -256,7 +266,7 @@ struct IndexValue { }; #if INDUCTOR_USE_VECTOR_TYPES() -template +template Welford welford_combine( Welford& acc, T& data, diff --git a/torch/csrc/inductor/static_cuda_launcher.cpp b/torch/csrc/inductor/static_cuda_launcher.cpp index da61cd28c1b6f..59916b6763bfa 100644 --- a/torch/csrc/inductor/static_cuda_launcher.cpp +++ b/torch/csrc/inductor/static_cuda_launcher.cpp @@ -1,4 +1,7 @@ -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) && !defined(USE_ROCM) +// We disable this file from being hipified because there are CUDA drivers hip +// has not implemented yet. Also, we're passing in a cubin file directly, so it +// would take more work to support ROCM anyway. 
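// [Illustrative sketch, not part of this patch] The CUDA driver-API sequence
// this file wraps, in its smallest form: load a compiled module (e.g. a
// cubin), look up a kernel by name, and launch it with an explicit grid and
// block shape. The file path and kernel name are placeholders, and error
// handling is reduced to a single check macro.
#include <cuda.h>
#include <cstdio>
#include <cstdlib>

#define DRV_CHECK(call)                                   \
  do {                                                    \
    CUresult rc = (call);                                 \
    if (rc != CUDA_SUCCESS) {                             \
      std::fprintf(stderr, "driver error %d\n", (int)rc); \
      std::exit(1);                                       \
    }                                                     \
  } while (0)

int main() {
  DRV_CHECK(cuInit(0));
  CUdevice dev;
  DRV_CHECK(cuDeviceGet(&dev, 0));
  CUcontext ctx;
  DRV_CHECK(cuCtxCreate(&ctx, 0, dev));

  CUmodule mod;
  DRV_CHECK(cuModuleLoad(&mod, "kernel.cubin"));          // hypothetical path
  CUfunction fn;
  DRV_CHECK(cuModuleGetFunction(&fn, mod, "my_kernel"));  // hypothetical name

  // A real kernel would pass its argument pointers via kernelParams.
  DRV_CHECK(cuLaunchKernel(fn,
                           /*gridX=*/1, /*gridY=*/1, /*gridZ=*/1,
                           /*blockX=*/32, /*blockY=*/1, /*blockZ=*/1,
                           /*sharedMemBytes=*/0, /*stream=*/nullptr,
                           /*kernelParams=*/nullptr, /*extra=*/nullptr));
  DRV_CHECK(cuCtxSynchronize());
  return 0;
}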
#include #include @@ -13,11 +16,6 @@ #include #include #include - -#if defined(USE_ROCM) -#include -#endif - /** Implements a static launcher for triton compiled CUDA kernels. Given a path to a cubin file, a function name, and some metadata, @@ -58,14 +56,8 @@ const at::cuda::NVRTC& nvrtc() { CUdeviceptr getPointer(PyObject* obj) { CUdeviceptr data_ptr = 0; - if (THPUtils_checkLong(obj)) { -#if defined(USE_ROCM) - data_ptr = reinterpret_cast(THPUtils_unpackUInt64(obj)); -#else data_ptr = THPUtils_unpackUInt64(obj); -#endif - return data_ptr; } if (obj == Py_None) { @@ -81,25 +73,13 @@ CUdeviceptr getPointer(PyObject* obj) { TORCH_CHECK( THPUtils_checkLong(ret), "data_ptr method of Pointer object must return 64-bit int"); - -#if defined(USE_ROCM) - data_ptr = reinterpret_cast(THPUtils_unpackUInt64(ret)); -#else data_ptr = THPUtils_unpackUInt64(ret); -#endif - if (!data_ptr) return data_ptr; CUdeviceptr dev_ptr = 0; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute( - &dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute( &dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); -#endif - return dev_ptr; } @@ -118,15 +98,6 @@ CUfunction loadKernel( } CUmodule mod = nullptr; CUfunction func = nullptr; - -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str())); - AT_CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str())); - int shared_optin = 0; - AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( - &shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device)); - -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str())); AT_CUDA_DRIVER_CHECK( nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str())); @@ -135,9 +106,6 @@ CUfunction loadKernel( &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); - -#endif - // Shared memory logic from triton/third-party/nvidia/backend/driver.c // If we're using more than 48 KB of shared memory, and we have // access to more than 48 KB of shared memory on the device, @@ -156,21 +124,6 @@ CUfunction loadKernel( " Reducing block sizes or `num_stages` may help."); if (sharedMemBytes > SHARED_MEM_STATIC_MAX && shared_optin > SHARED_MEM_STATIC_MAX) { -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipFuncSetCacheConfig(func, hipFuncCachePreferShared)); - int shared_total = 0, shared_static = 0; - AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( - &shared_total, - hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, - device)); - AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( - &shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); - AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute( - func, - hipFuncAttributeMaxDynamicSharedMemorySize, - shared_optin - shared_static)); - -#else AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED)); int shared_total = 0, shared_static = 0; @@ -184,7 +137,6 @@ CUfunction loadKernel( func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); -#endif } return func; } @@ -200,27 +152,6 @@ inline void launchKernel( cudaStream_t stream) { // cta_args is always 1 for inductor generated triton kernels, // so we don't need to figure out grid dimension here -#if defined(USE_ROCM) - int device = 0; - AT_CUDA_DRIVER_CHECK(hipGetDevice(&device)); - int warp_size = 0; - AT_CUDA_DRIVER_CHECK( - hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device)); - - AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel( - func, - gridX, - gridY, - 
gridZ, - warp_size * numWarps, // blockDim.x - 1, // blockDim.y - 1, // blockDim.z - sharedMemBytes, - stream, - args, - nullptr)); - -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( func, gridX, @@ -233,7 +164,6 @@ inline void launchKernel( stream, args, nullptr)); -#endif } template @@ -339,20 +269,11 @@ PyObject* load_kernel(PyObject* self, PyObject* args) { CUdevice device = static_cast(device_ptr); // NOLINT CUfunction func = nullptr; func = loadKernel(filePath, funcName, sharedMemBytes, device); - -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK( - hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func)); - AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( - &n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); - -#else + // Taken from triton/nvidia/backend/driver.c AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func)); AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute( &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); - -#endif n_spills /= 4; // Return a tuple of CUFunction, n_regs, n_spills return Py_BuildValue( @@ -378,6 +299,7 @@ PyObject* launch_kernel_inner( std::array argStorage = {}; std::array kernelArgs = {}; parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data()); + launchKernel( func, gridX, @@ -464,25 +386,13 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) { Py_RETURN_NONE; } CUcontext pctx = nullptr; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); -#endif - if (!pctx) { // Ensure device context exists CUdevice device = 0; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0)); - AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device)); - AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0)); AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device)); AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx)); - -#endif } CUfunction func = reinterpret_cast(func_ptr); // NOLINT cudaStream_t cudaStream = reinterpret_cast(stream); // NOLINT diff --git a/torch/csrc/inductor/static_cuda_launcher.h b/torch/csrc/inductor/static_cuda_launcher.h index 6f3980172275b..517036b9975e6 100644 --- a/torch/csrc/inductor/static_cuda_launcher.h +++ b/torch/csrc/inductor/static_cuda_launcher.h @@ -1,5 +1,5 @@ #pragma once -#if defined(USE_CUDA) +#if defined(USE_CUDA) && !defined(USE_ROCM) #include #include diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index f191c7daf6e26..8e0d94b59acab 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -679,7 +679,7 @@ Value* emitBuiltinCall( at::ArrayRef args, at::ArrayRef kwargs, const std::optional& self) { - const auto& variants = getAllOperatorsFor(name); + auto variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); // first let's set the graph's version diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index 29964e0918534..31fc483812ab0 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -308,12 +308,6 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const { if (auto custom_class_type = getCustomClass(*name)) { return custom_class_type; } - // Check if the type is a custom class. 
This is done by checking - // if type_name starts with "torch.classes." - if (name->find("torch.classes.") == 0) { - auto custom_class_type = getCustomClass("__torch__." + *name); - return custom_class_type; - } throw ErrorReport(expr) << "Unknown type name '" << *name << "'"; } diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 16edf669da9be..f1353bd3103cc 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -617,7 +617,7 @@ void AliasDb::analyzeImpl(Node* node) { oss << input->type()->str() << ", "; } oss << "\n\nCandidates:"; - const auto& candidates = getAllOperatorsFor(node->kind()); + auto candidates = getAllOperatorsFor(node->kind()); for (const auto& candidate : candidates) { oss << "\n\t" << candidate->schema(); } diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 4368b3c8191d8..6febed3540526 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1088,7 +1088,7 @@ const FunctionSchema* Node::maybeSchema() const { const Operator* Node::maybeOperator() const { if (!op_) { - const auto& candidates = getAllOperatorsFor(kind()); + auto candidates = getAllOperatorsFor(kind()); for (const auto& candidate : candidates) { if (matches(candidate->schema())) { op_ = candidate.get(); diff --git a/torch/csrc/jit/jit_log.cpp b/torch/csrc/jit/jit_log.cpp index 8adf4c8aab10c..745d397f593c0 100644 --- a/torch/csrc/jit/jit_log.cpp +++ b/torch/csrc/jit/jit_log.cpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include #include #include @@ -113,7 +113,12 @@ void JitLoggingConfig::parse() { bool is_enabled(const char* cfname, JitLoggingLevels level) { const auto& files_to_levels = JitLoggingConfig::getInstance().getFilesToLevels(); - const auto fname_no_ext = c10::filesystem::path(cfname).stem().string(); + std::string fname{cfname}; + fname = c10::detail::StripBasename(fname); + const auto end_index = fname.find_last_of('.') == std::string::npos + ? 
fname.size() + : fname.find_last_of('.'); + const auto fname_no_ext = fname.substr(0, end_index); const auto it = files_to_levels.find(fname_no_ext); if (it == files_to_levels.end()) { @@ -156,7 +161,7 @@ std::string jit_log_prefix( std::stringstream prefix_ss; prefix_ss << "["; prefix_ss << level << " "; - prefix_ss << c10::filesystem::path(fn).filename() << ":"; + prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":"; prefix_ss << std::setfill('0') << std::setw(3) << l; prefix_ss << "] "; diff --git a/torch/csrc/jit/jit_opt_limit.cpp b/torch/csrc/jit/jit_opt_limit.cpp index 385cbe4acdc95..c4c1a2307659f 100644 --- a/torch/csrc/jit/jit_opt_limit.cpp +++ b/torch/csrc/jit/jit_opt_limit.cpp @@ -1,10 +1,11 @@ +#include #include #include #include #include #include -#include +#include #include #include #include @@ -56,7 +57,9 @@ bool opt_limit(const char* pass_name) { static const std::unordered_map passes_to_opt_limits = parseJITOptLimitOption(opt_limit.value()); - auto pass = c10::filesystem::path(pass_name).stem().string(); + std::string pass{pass_name}; + pass = c10::detail::StripBasename(pass); + pass = c10::detail::ExcludeFileExtension(pass); auto opt_limit_it = passes_to_opt_limits.find(pass); if (opt_limit_it == passes_to_opt_limits.end()) { diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index beb6f89519804..f7d855a515789 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1693,7 +1693,7 @@ void initJITBindings(PyObject* module) { [](const std::string& op_name, const std::string& overload_name) { try { auto symbol = Symbol::fromQualString(op_name); - const auto& operations = getAllOperatorsFor(symbol); + auto operations = getAllOperatorsFor(symbol); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { return op->schema(); @@ -1714,7 +1714,7 @@ void initJITBindings(PyObject* module) { const std::string& overload_name) -> std::optional { try { auto symbol = Symbol::fromQualString(op_name); - const auto& operations = getAllOperatorsFor(symbol); + auto operations = getAllOperatorsFor(symbol); bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { @@ -2138,7 +2138,7 @@ void initJITBindings(PyObject* module) { m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); - const auto& operations = getAllOperatorsFor(symbol); + auto operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 35dead2a395c9..6f9dec70cddc9 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -53,6 +53,16 @@ struct OperatorRegistry { to_register.clear(); } + const std::vector>& getOperatorsWithLockHeld( + Symbol name) { + registerPendingOperators(); + static std::vector> empty; + auto it = operators.find(name); + if (it != operators.end()) + return it->second; + return empty; + } + public: void registerOperator(Operator&& op) { std::lock_guard guard(lock); @@ -143,14 +153,35 @@ struct OperatorRegistry { return it->second; } - const std::vector>& getOperators(Symbol name) { + // This function returns internal lock-protected state. We need to + // copy it to avoid race conditions. 
+ std::vector> getOperators(Symbol name) { std::lock_guard guard(lock); - registerPendingOperators(); - static std::vector> empty; - auto it = operators.find(name); - if (it != operators.end()) - return it->second; - return empty; + return getOperatorsWithLockHeld(name); + } + + std::vector> getSortedOperators(Symbol name) { + std::lock_guard guard(lock); + const auto& unsortedOps = getOperatorsWithLockHeld(name); + // Depending on the order of registration, aten or jit ops may be + // registered first. This sorting is helpful in cases where + // deterministic (i.e. not dependent on build config) behavior is + // desired; e.g. torch.ops.aten.* uses this function, and tries to + // find the "first" op that matches input args. Without the sorting, + // the "first" op may change depending on registration order. + std::vector> sortedOps; + sortedOps.reserve(unsortedOps.size()); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return op->isC10Op(); }); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return !op->isC10Op(); }); + return sortedOps; } std::vector findSimilarOperators(Symbol input_op) { @@ -387,35 +418,16 @@ void deregisterOperator(const FunctionSchema& schema) { getRegistry().deregisterOperator(schema); } -const std::vector> getAllOperators() { +std::vector> getAllOperators() { return getRegistry().getAllOperators(); } -const std::vector>& getAllOperatorsFor(Symbol name) { +std::vector> getAllOperatorsFor(Symbol name) { return getRegistry().getOperators(name); } std::vector> getAllSortedOperatorsFor(Symbol name) { - const auto& unsortedOps = getAllOperatorsFor(name); - // Depending on the order of registration, aten or jit ops may be - // registered first. This sorting is helpful in cases where - // deterministic (i.e. not dependent on build config) behavior is - // desired; e.g. torch.ops.aten.* uses this function, and tries to - // find the "first" op that matches input args. Without the sorting, - // the "first" op may change depending on registration order. - std::vector> sortedOps; - sortedOps.reserve(unsortedOps.size()); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return op->isC10Op(); }); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return !op->isC10Op(); }); - return sortedOps; + return getRegistry().getSortedOperators(name); } std::shared_ptr findOperatorFor(const c10::OperatorName& full_name) { diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index bde3825f5ea38..6b6972deeebf0 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -260,8 +260,9 @@ struct TORCH_API Operator { TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); -TORCH_API const std::vector> getAllOperators(); -TORCH_API const std::vector>& getAllOperatorsFor( +TORCH_API std::vector> getAllOperators(); +// This function returns a copy for thread safety. +TORCH_API std::vector> getAllOperatorsFor( Symbol name); // Returns operators in the order which OpOverloadPacket resolves them. 
TORCH_API std::vector> getAllSortedOperatorsFor( diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index 74f87e46757ea..b1f0f410f14fe 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -79,7 +79,7 @@ auto compilation_unit = std::make_shared(); const std::optional getInplaceVariant( const FunctionSchema& base_schema) { - auto& inplace_variants = + auto inplace_variants = getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_")); for (const auto& variant : inplace_variants) { diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index 1c4d9ce295a84..ffbb7bb1235a7 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -560,3 +560,21 @@ torch_get_num_threads(uint32_t* out_num_threads) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( { *out_num_threads = static_cast(at::get_num_threads()); }); } + +AOTI_TORCH_EXPORT AOTITorchError +torch_get_const_data_ptr(AtenTensorHandle tensor, const void** ret_data_ptr) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* t = + torch::aot_inductor::tensor_handle_to_tensor_pointer(tensor); + *ret_data_ptr = t->const_data_ptr(); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_get_mutable_data_ptr(AtenTensorHandle tensor, void** ret_data_ptr) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* t = + torch::aot_inductor::tensor_handle_to_tensor_pointer(tensor); + *ret_data_ptr = t->mutable_data_ptr(); + }); +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 0afa650fe2d7c..99b3b435cf550 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -92,6 +92,17 @@ AOTI_TORCH_EXPORT AOTITorchError torch_get_thread_idx(uint32_t* out_thread_idx); AOTI_TORCH_EXPORT AOTITorchError torch_get_num_threads(uint32_t* out_num_threads); +// Get a pointer to the underlying storage data +AOTI_TORCH_EXPORT AOTITorchError torch_get_mutable_data_ptr( + AtenTensorHandle tensor, + void** ret_data_ptr // returns borrowed reference +); + +AOTI_TORCH_EXPORT AOTITorchError torch_get_const_data_ptr( + AtenTensorHandle tensor, + const void** ret_data_ptr // returns borrowed reference +); + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 #ifdef __cplusplus diff --git a/torch/csrc/stable/tensor_inl.h b/torch/csrc/stable/tensor_inl.h index 8eb69f1a63b74..8f7be7b4aabbd 100644 --- a/torch/csrc/stable/tensor_inl.h +++ b/torch/csrc/stable/tensor_inl.h @@ -33,4 +33,37 @@ inline Device Tensor::device() const { return Device(extension_device_type, static_cast(device_index)); } +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 +// The following data ptr cast methods mirror the methods defined in +// aten/src/ATen/templates/TensorMethods.cpp +#define DEFINE_DATA_PTR_CAST(T, name, PRED) \ + template <> \ + inline T* Tensor::mutable_data_ptr() const { \ + auto stype = scalar_type(); \ + STD_TORCH_CHECK( \ + PRED(stype, torch::headeronly::ScalarType::name), \ + "expected scalar type " #name " but found ", \ + torch::headeronly::toString(stype)); \ + return static_cast(mutable_data_ptr()); \ + } \ + template <> \ + inline const T* Tensor::const_data_ptr() const { \ + auto stype = scalar_type(); \ + STD_TORCH_CHECK( \ + PRED(stype, torch::headeronly::ScalarType::name), \ + "expected scalar type " #name " but found ", \ + torch::headeronly::toString(stype)); \ + return static_cast(const_data_ptr()); \ + } + +#define _PRED(S1, S2) S1 == S2 
+#define DEFINE_CAST(T, name) DEFINE_DATA_PTR_CAST(T, name, _PRED) +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST) +DEFINE_CAST(uint16_t, UInt16) +DEFINE_CAST(uint32_t, UInt32) +DEFINE_CAST(uint64_t, UInt64) +#undef DEFINE_CAST +#undef _PRED +#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + HIDDEN_NAMESPACE_END(torch, stable) diff --git a/torch/csrc/stable/tensor_struct.h b/torch/csrc/stable/tensor_struct.h index e3f50ad26781c..78a7430f555d4 100644 --- a/torch/csrc/stable/tensor_struct.h +++ b/torch/csrc/stable/tensor_struct.h @@ -78,12 +78,34 @@ class Tensor { // semantics as their counterparts in TensorBase.h. // ============================================================================= + // Do not add new uses of data_ptr(), use const_data_ptr() if + // possible, mutable_data_ptr() otherwise. void* data_ptr() const { void* data_ptr; TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); return data_ptr; } +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + void* mutable_data_ptr() const { + void* data_ptr{}; + TORCH_ERROR_CODE_CHECK(torch_get_mutable_data_ptr(ath_.get(), &data_ptr)); + return data_ptr; + } + + const void* const_data_ptr() const { + const void* data_ptr{}; + TORCH_ERROR_CODE_CHECK(torch_get_const_data_ptr(ath_.get(), &data_ptr)); + return data_ptr; + } + + template + T* mutable_data_ptr() const; + + template , int> = 0> + const T* const_data_ptr() const; +#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + int64_t dim() const { int64_t dim; TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); diff --git a/torch/csrc/utils/pyobject_preservation.cpp b/torch/csrc/utils/pyobject_preservation.cpp index 4f2d0a2507011..a652cbdb7aefd 100644 --- a/torch/csrc/utils/pyobject_preservation.cpp +++ b/torch/csrc/utils/pyobject_preservation.cpp @@ -1,19 +1,67 @@ #include -#include - -void clear_slots(PyTypeObject* type, PyObject* self) { - Py_ssize_t n = Py_SIZE(type); - PyMemberDef* mp = type->tp_members; - - for (Py_ssize_t i = 0; i < n; i++, mp++) { - if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) { - char* addr = (char*)self + mp->offset; - PyObject* obj = *(PyObject**)addr; - if (obj != nullptr) { - *(PyObject**)addr = nullptr; - Py_DECREF(obj); - } +#include +#include + +namespace torch::utils { + +using c10::intrusive_ptr_target; +using c10::impl::PyObjectSlot; + +void PyObjectPreservation::init_fresh_nonatomic( + intrusive_ptr_target* target, + PyObjectSlot* slot, + PyObject* pyobj) { + TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr); + TORCH_INTERNAL_ASSERT( + target->combined_refcount_.load(std::memory_order_relaxed) == + c10::detail::kUniqueRef); + + slot->pyobj_.store(pyobj, std::memory_order_relaxed); + slot->pyobj_interpreter_.store( + c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed); + target->combined_refcount_.store( + c10::detail::kHasPyObject | c10::detail::kUniqueRef, + std::memory_order_relaxed); +} + +PyObject* PyObjectPreservation::init_once( + intrusive_ptr_target* target, + PyObjectSlot* slot, + PyObject* pyobj) { + PyObject* expected = nullptr; + if (!slot->pyobj_.compare_exchange_strong( + expected, pyobj, std::memory_order_acq_rel)) { + TORCH_INTERNAL_ASSERT(expected != nullptr); + return expected; + } + + slot->pyobj_interpreter_.store( + c10::impl::getGlobalPyInterpreter(), std::memory_order_release); + + bool increfed = false; + auto combined = target->combined_refcount_.load(std::memory_order_relaxed); + do { + TORCH_INTERNAL_ASSERT(!c10::detail::has_pyobject(combined)); + if 
(c10::detail::refcount(combined) > 1 && !increfed) { + // We need to incref the object to preserve the invariant that + // if refcount > 1, the c10 object holds a reference to the PyObject. + // This must happen before we set the kHasPyObject bit. + Py_INCREF(pyobj); + increfed = true; } + } while (!target->combined_refcount_.compare_exchange_weak( + combined, + combined | c10::detail::kHasPyObject, + std::memory_order_acq_rel, + std::memory_order_relaxed)); + + if (increfed && c10::detail::refcount(combined) == 1) { + // Fix up if refcount if we did the incref in a failed compare-exchange + Py_DECREF(pyobj); } + + return pyobj; } + +} // namespace torch::utils diff --git a/torch/csrc/utils/pyobject_preservation.h b/torch/csrc/utils/pyobject_preservation.h index 456095d7b7037..b060bc034b2c3 100644 --- a/torch/csrc/utils/pyobject_preservation.h +++ b/torch/csrc/utils/pyobject_preservation.h @@ -4,4 +4,28 @@ // This file contains utilities used for handling PyObject preservation -void clear_slots(PyTypeObject* type, PyObject* self); +namespace c10 { +class intrusive_ptr_target; +namespace impl { +struct PyObjectSlot; +} // namespace impl +} // namespace c10 + +namespace torch::utils { + +class PyObjectPreservation { + public: + // Store a PyObject wrapper on a fresh c10 wrapper. The caller must hold + // a unique reference to `target`. + static void init_fresh_nonatomic( + c10::intrusive_ptr_target* target, + c10::impl::PyObjectSlot* slot, + PyObject* pyobj); + + static PyObject* init_once( + c10::intrusive_ptr_target* target, + c10::impl::PyObjectSlot* slot, + PyObject* pyobj); +}; + +} // namespace torch::utils diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 6fc3cc1d4e670..95e9509cdbcd6 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -205,7 +205,7 @@ def __init__(self, strategies: list[OpSpec]) -> None: def __str__(self) -> str: strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies]) mesh_shape = self.mesh_shape - return f"OpStragety[{strategy_list_str}] @ mesh: {mesh_shape}" + return f"OpStrategy[{strategy_list_str}] @ mesh: {mesh_shape}" def max_num_shards(self) -> int: """ diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bacc95d4c9154..44d33e3f73ac4 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -470,6 +470,10 @@ def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> b return a.node.shape_env.bound_sympy(a.node.expr).is_singleton() # type: ignore[union-attr] +@deprecated( + "guard_size_oblivious will be removed. Consider using explicit unbacked handling \ + potentially utilizing guard_or_false, guard_or_true, or statically_known_true" +) def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool: """ Perform a guard on a symbolic boolean expression in a size oblivious way. 
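The symbolic_shapes.py hunk above deprecates guard_size_oblivious in favor of explicit unbacked handling through guard_or_false, guard_or_true, or statically_known_true. A minimal migration sketch follows; the helper function and the size check in it are hypothetical illustrations of the pattern the deprecation message points at, not code from this patch:

    # Illustrative only: maybe_squeeze_batch is a hypothetical helper showing
    # the migration suggested by the deprecation message above.
    import torch
    from torch.fx.experimental.symbolic_shapes import guard_or_false

    def maybe_squeeze_batch(x: torch.Tensor) -> torch.Tensor:
        # Previously: if guard_size_oblivious(x.size(0) == 1): ...
        # guard_or_false takes the safe branch (returns False) whenever the
        # comparison cannot be decided for an unbacked, data-dependent size.
        if guard_or_false(x.size(0) == 1):
            return x.squeeze(0)
        return x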
diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index e475a5bc9b6df..3da33923d5363 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -576,17 +576,6 @@ def go(node, keypath): if i0 in constrained_unbacked_symbols: continue # constrain symbol just once - if i0 in shape_env.size_like: - if export: - graph.call_function( - torch.ops.aten.sym_constrain_range_for_size.default, - (expr_to_proxy[i0].node,), - ) - else: - graph.call_function( - torch._check_is_size, (expr_to_proxy[i0].node,) - ) - vr = shape_env.var_to_range[i0] if vr.is_int and vr.upper == sys.maxsize - 1: # treat upper bound == sys.maxsize - 1 for int symbols as +oo diff --git a/torch/headeronly/core/ScalarType.h b/torch/headeronly/core/ScalarType.h index 5c395e5d5aa29..ce43ce6866cd9 100644 --- a/torch/headeronly/core/ScalarType.h +++ b/torch/headeronly/core/ScalarType.h @@ -336,6 +336,13 @@ inline std::ostream& operator<<( return stream << toString(scalar_type); } +inline bool isQIntType(ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || + t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || + t == ScalarType::QUInt2x4; +} + inline ScalarType toUnderlying(ScalarType t) { switch (t) { case ScalarType::QUInt8: @@ -362,6 +369,7 @@ using c10::NumScalarTypes; using c10::ScalarType; using c10::toString; using c10::operator<<; +using c10::isQIntType; using c10::toUnderlying; namespace impl { diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 2f59b520a8b43..6724ab2ae739a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14264,15 +14264,11 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ], ), BinaryUfuncInfo('logaddexp', dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), - dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.float16, torch.complex32), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, - supports_rhs_python_scalar=False, - skips=( - # TODO: FIXME: RuntimeError: not implemented for 'ComplexFloat' - DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), - )), + supports_rhs_python_scalar=False), OpInfo('logaddexp2', dtypes=floating_types_and(torch.bfloat16, torch.half), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), @@ -23643,10 +23639,12 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): torch_opinfo_name="logaddexp", skips=( # failure due to mismatch in edge cases, which boils down to what torch.exp(inf + infj) should be - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='cpu', - dtypes=(torch.complex64, torch.complex128)), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='cpu', - dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + 
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                         dtypes=(torch.complex32, torch.complex64, torch.complex128)),
         ),
     ),
     PythonRefInfo(
diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
index fb53ce4439afd..6ce7d4b2ca507 100644
--- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py
+++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -43,7 +43,6 @@
     SequenceParallel,
 )
 from torch.testing._internal.common_distributed import (
-    ACCELERATOR_DIST_BACKENDS,
     MultiProcContinuousTest,
     MultiProcessTestCase,
     MultiThreadedTestCase,
@@ -397,17 +396,14 @@ def build_device_mesh(self) -> DeviceMesh:
         return init_device_mesh(self.device_type, (self.world_size,))
 
     def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
-        if backend is None:
-            backend = self.backend
-
-        requires_gpu = any(
-            gpu_backend in backend for gpu_backend in ACCELERATOR_DIST_BACKENDS
-        )
-        if requires_gpu and torch.accelerator.device_count() < self.world_size:
+        if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
 
         curr_backend = dist.get_default_backend_for_device(self.device_type)
 
+        if backend is None:
+            backend = self.backend
+
         if backend not in [
             "nccl",
             "gloo",
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index b74e4d01da060..e1e38e0c36959 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -106,7 +106,7 @@ class DefaultDeviceType:
     to save and restore for recomputation.
     """
 
-    _default_device_type = "cuda"
+    _default_device_type: Optional[str] = None
 
     @staticmethod
     def set_device_type(device: str = "cuda") -> None:
@@ -126,6 +126,9 @@ def get_device_type() -> str:
         Returns:
             str: The current default device type.
         """
+        if not DefaultDeviceType._default_device_type:
+            DefaultDeviceType._default_device_type = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+
         return DefaultDeviceType._default_device_type
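The final checkpoint.py hunk makes DefaultDeviceType resolve its default lazily from torch.accelerator.current_accelerator instead of hard-coding "cuda". A small usage sketch of the new behavior; the printed value is environment-dependent and the explicit override is only an example:

    # Illustrative only: the detected default depends on the running machine.
    from torch.utils.checkpoint import DefaultDeviceType

    # With no explicit override, the default now follows the detected
    # accelerator (e.g. "cuda" or "xpu") and falls back to "cpu" when no
    # accelerator is available.
    print(DefaultDeviceType.get_device_type())

    # Setting an explicit device type still takes precedence.
    DefaultDeviceType.set_device_type("cpu")
    assert DefaultDeviceType.get_device_type() == "cpu"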