From f8c3ea23f3769978d406864045b8f58f219cf6b3 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Wed, 8 Oct 2025 16:04:32 -0700
Subject: [PATCH 1/2] update cuda delegate resource-freeing pipeline for safe,
 segfault-free teardown

Pull Request resolved: https://github.com/pytorch/executorch/pull/14905

This diff reworks the `clear_all_tensors()` function and enables it during
the backend destroy stage. Furthermore, we defer deletion of the container
handle to the OS, to avoid a potential segfault when more than one .so file
is loaded.

ghstack-source-id: 314984329
@exported-using-ghexport

Differential Revision: [D84135792](https://our.internmc.facebook.com/intern/diff/D84135792/)
---
 backends/cuda/runtime/cuda_backend.cpp | 30 +++++++-------------------
 backends/cuda/runtime/shims/memory.cpp | 15 +++++++++----
 2 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index 653687fccb7..e10322ad40c 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -286,18 +286,6 @@ class ET_EXPERIMENTAL CudaBackend final
           i);
     }
 
-    // Clean up GPU tensors that we created (ExecuTorch tensors are always
-    // CPU, so all GPU tensors are our copies)
-    for (int i = 0; i < n_inputs; i++) {
-      // All GPU input tensors were created by us, delete them
-      aoti_torch_delete_tensor_object(gpu_inputs[i]);
-    }
-
-    for (int i = 0; i < n_outputs; i++) {
-      // All GPU output tensors were created by us, delete them
-      aoti_torch_delete_tensor_object(gpu_outputs[i]);
-    }
-
     return Error::Ok;
   }
 
@@ -318,16 +306,13 @@ class ET_EXPERIMENTAL CudaBackend final
       handle->cuda_stream = nullptr;
     }
 
-    // Delete the container BEFORE closing the shared library
-    if (handle->container_handle != nullptr) {
-      AOTIRuntimeError delete_result =
-          AOTInductorModelContainerDelete(handle->container_handle);
-      ET_CHECK_OR_LOG_ERROR(
-          delete_result == Error::Ok,
-          "Failed to delete AOTInductorModelContainer with error code %d",
-          delete_result);
-      handle->container_handle = nullptr;
-    }
+    // NOTE: AOTInductorModelContainerDelete does not work correctly with
+    // multiple .so files. Deleting one container frees shared resources,
+    // which causes segmentation faults when attempting to delete other
+    // containers. As a workaround, we skip explicit container deletion
+    // and defer cleanup to the OS.
+    // TODO(gasoonjia): Find a proper solution for safe container deletion.
+    // AOTInductorModelContainerDelete(handle->container_handle);
 
     // Now close the shared library
     if (handle->so_handle != nullptr) {
@@ -346,6 +331,7 @@ class ET_EXPERIMENTAL CudaBackend final
     }
 
     delete handle;
+    clear_all_tensors();
   }
 };
 
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index cbaca68576e..a054169330b 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -271,14 +271,21 @@ void clear_all_tensors() {
   // Use aoti_torch_delete_tensor_object to properly delete each tensor
   // Note: We need to collect tensor pointers first since deletion modifies the
   // set
-  auto old_tensors =
-      std::move(tensors); // tensors is now empty and no need to copy
-  for (const auto& tensor_shared : old_tensors) {
-    aoti_torch_delete_tensor_object(tensor_shared.get());
+  std::vector<Tensor*> tensor_ptrs;
+  tensor_ptrs.reserve(tensors.size());
+  for (const auto& tensor_shared : tensors) {
+    tensor_ptrs.push_back(tensor_shared.get());
+  }
+
+  // Now delete each tensor - this will modify the global tensors set
+  for (Tensor* tensor_ptr : tensor_ptrs) {
+    aoti_torch_delete_tensor_object(tensor_ptr);
   }
 
   // tensors set should now be empty, but ensure it's cleared
   tensors.clear();
+
+  ET_LOG(Info, "Cleared all tensors");
 }
 
 AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
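Note on the `clear_all_tensors()` rewrite in PATCH 1/2: erasing an element of a `std::unordered_set` invalidates the iterator to that element, so deleting tensors while iterating the global set directly would be undefined behavior, and moving the set out (the old approach) leaves the global set empty while `aoti_torch_delete_tensor_object` presumably still needs to find each entry in it. The snapshot-then-delete pattern below is a minimal standalone sketch of the same idea; `LiveTensor`, `live_tensors`, and `destroy_tensor` are placeholder names, not the shim's real API.

```cpp
#include <memory>
#include <unordered_set>
#include <vector>

struct LiveTensor {}; // stand-in for the backend's Tensor type

// Global registry keeping tensors alive, analogous to the shim's `tensors` set.
std::unordered_set<std::shared_ptr<LiveTensor>> live_tensors;

// Erases the matching entry from the registry, the way
// aoti_torch_delete_tensor_object removes the tensor it is given.
void destroy_tensor(LiveTensor* t) {
  for (auto it = live_tensors.begin(); it != live_tensors.end(); ++it) {
    if (it->get() == t) {
      live_tensors.erase(it); // invalidates `it`, so return immediately
      return;
    }
  }
}

void clear_all() {
  // Snapshot raw pointers first: each destroy_tensor() call mutates
  // live_tensors, so iterating it directly while deleting is unsafe.
  std::vector<LiveTensor*> ptrs;
  ptrs.reserve(live_tensors.size());
  for (const auto& sp : live_tensors) {
    ptrs.push_back(sp.get());
  }
  for (LiveTensor* t : ptrs) {
    destroy_tensor(t);
  }
  live_tensors.clear(); // should already be empty; clear defensively
}
```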
From 244d262b02c338de5cbb13ce1e404129f4c2e6a1 Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Thu, 9 Oct 2025 01:17:03 -0400
Subject: [PATCH 2/2] remove extra libtorch dependency (#14928)

This PR was created by the merge bot to help merge the original PR into the
main branch.
ghstack PR number: https://github.com/pytorch/executorch/pull/14919 by @Gasoonjia
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/51/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/51/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/50/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/51/orig

Differential Revision: [D84207378](https://our.internmc.facebook.com/intern/diff/D84207378/)

@diff-train-skip-merge

---------

Co-authored-by: gasoonjia
---
 backends/aoti/CMakeLists.txt                     |  2 +-
 backends/aoti/common_shims.cpp                   |  9 ++++++++-
 backends/aoti/common_shims.h                     |  2 ++
 backends/cuda/CMakeLists.txt                     |  5 +----
 backends/cuda/runtime/cuda_backend.cpp           | 11 ++++-------
 backends/cuda/runtime/guard.cpp                  |  8 ++------
 backends/cuda/runtime/guard.h                    |  8 ++------
 backends/cuda/runtime/shims/cuda_guard.cpp       |  8 ++------
 backends/cuda/runtime/shims/cuda_guard.h         |  8 ++------
 backends/cuda/runtime/shims/memory.cpp           |  8 ++------
 backends/cuda/runtime/shims/memory.h             |  8 ++------
 backends/cuda/runtime/shims/tensor_attribute.cpp |  8 ++------
 backends/cuda/runtime/shims/tensor_attribute.h   |  8 ++------
 backends/cuda/runtime/utils.h                    |  8 ++------
 14 files changed, 34 insertions(+), 67 deletions(-)

diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt
index ce364f2c4b0..845144af50f 100644
--- a/backends/aoti/CMakeLists.txt
+++ b/backends/aoti/CMakeLists.txt
@@ -40,7 +40,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
 # Ensure symbols are exported properly
 target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
 
-# Link against PyTorch libraries and standard libraries
+# Link against ExecuTorch libraries and standard libraries
 target_link_libraries(aoti_common PUBLIC
extension_tensor ${CMAKE_DL_LIBS}) executorch_target_link_options_shared_lib(aoti_common) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index 2f9b36e3c4f..abc83779443 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() { } // Dtype constants - these return the PyTorch dtype codes -// Currently only float32 is supported, but using robust enum-based approach int32_t aoti_torch_dtype_float32() { return 6; // PyTorch's float32 dtype code } +int32_t aoti_torch_dtype_bfloat16() { + return 15; // PyTorch's bfloat16 dtype code +} + +int32_t aoti_torch_dtype_int64() { + return 4; // PyTorch's int64 dtype code +} + // Cleanup functions void cleanup_tensor_metadata() { internal::tensor_to_sizes.clear(); diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index ffcbaa11a08..5f54cd1c878 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim); int32_t aoti_torch_device_type_cpu(); int32_t aoti_torch_layout_strided(); int32_t aoti_torch_dtype_float32(); +int32_t aoti_torch_dtype_bfloat16(); +int32_t aoti_torch_dtype_int64(); // Autograd mode functions int32_t aoti_torch_grad_mode_is_enabled(); diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index dc5b1b786f8..575f676e4cc 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -55,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic) # Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries target_link_libraries( - aoti_cuda - PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS} - # Link PyTorch libraries for AOTI CUDA functions - ${TORCH_LIBRARIES} + aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS} ) # If you need other CUDA libraries, link them similarly: # target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...) 
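Note on the dtype constants added to common_shims.cpp above: the shims return PyTorch's ScalarType integer codes (int64 = 4, float32 = 6, bfloat16 = 15). The sketch below shows how a caller might branch on these codes; `element_size_for_dtype` is a hypothetical helper for illustration, not part of this patch.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper, not part of this patch: element sizes for the
// dtype codes the shims above return. The case values match PyTorch's
// ScalarType numbering, which is what the shims hard-code.
inline size_t element_size_for_dtype(int32_t dtype) {
  switch (dtype) {
    case 4:  return 8; // int64, as returned by aoti_torch_dtype_int64()
    case 6:  return 4; // float32, as returned by aoti_torch_dtype_float32()
    case 15: return 2; // bfloat16, as returned by aoti_torch_dtype_bfloat16()
    default: return 0; // unsupported code; callers should treat as an error
  }
}
```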
diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index e10322ad40c..58ab54e1aac 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -26,9 +26,7 @@
 #include 
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 #define LOAD_SYMBOL(name, handle) \
   do { \
@@ -335,14 +333,13 @@ class ET_EXPERIMENTAL CudaBackend final
   }
 };
 
-} // namespace cuda
+} // namespace executorch::backends::cuda
 
+namespace executorch::backends {
 namespace {
 auto cls = cuda::CudaBackend();
 executorch::runtime::Backend backend{"CudaBackend", &cls};
 static executorch::runtime::Error success_with_compiler =
     register_backend(backend);
 } // namespace
-
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends
diff --git a/backends/cuda/runtime/guard.cpp b/backends/cuda/runtime/guard.cpp
index 885efc7670d..674cc6387b3 100644
--- a/backends/cuda/runtime/guard.cpp
+++ b/backends/cuda/runtime/guard.cpp
@@ -9,9 +9,7 @@
 #include 
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 namespace {
 // Thread-local stream storage (private to this file)
@@ -146,6 +144,4 @@ Result CUDAStreamGuard::create(
   return stream_guard;
 }
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/guard.h b/backends/cuda/runtime/guard.h
index 4e5a18a4c0f..3f187000f90 100644
--- a/backends/cuda/runtime/guard.h
+++ b/backends/cuda/runtime/guard.h
@@ -14,9 +14,7 @@
 #include 
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 using executorch::runtime::Error;
 using executorch::runtime::Result;
@@ -190,6 +188,4 @@ class CUDAStreamGuard {
   DeviceIndex device_index_;
 };
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/cuda_guard.cpp b/backends/cuda/runtime/shims/cuda_guard.cpp
index 5740d0bf654..bb07acc7ffa 100644
--- a/backends/cuda/runtime/shims/cuda_guard.cpp
+++ b/backends/cuda/runtime/shims/cuda_guard.cpp
@@ -8,9 +8,7 @@
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 extern "C" {
 
@@ -104,6 +102,4 @@ AOTITorchError aoti_torch_get_current_cuda_stream(
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/cuda_guard.h b/backends/cuda/runtime/shims/cuda_guard.h
index 6da869064a7..f930f3df643 100644
--- a/backends/cuda/runtime/shims/cuda_guard.h
+++ b/backends/cuda/runtime/shims/cuda_guard.h
@@ -13,9 +13,7 @@
 #include 
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 using executorch::backends::aoti::AOTITorchError;
 
@@ -99,6 +97,4 @@ AOTITorchError aoti_torch_get_current_cuda_stream(
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
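Note on the guard files touched above: guard.h and guard.cpp implement RAII guards for the active CUDA device and stream, created through a fallible `create()` that returns a `Result` (which a plain constructor cannot express). The sketch below is a simplified standalone illustration of the underlying RAII idea using only the raw CUDA runtime API; `ScopedCudaDevice` is an invented name and error returns are ignored for brevity, so this is not the backend's actual `CUDAStreamGuard`.

```cpp
#include <cuda_runtime.h>

// Minimal RAII sketch: remember the current device, switch to the requested
// one for the lifetime of the guard, and restore the original on scope exit.
class ScopedCudaDevice {
 public:
  explicit ScopedCudaDevice(int device) {
    cudaGetDevice(&prev_device_); // remember the current device
    cudaSetDevice(device);        // switch for the guarded scope
  }
  ~ScopedCudaDevice() {
    cudaSetDevice(prev_device_);  // restore on scope exit
  }
  // Non-copyable: exactly one guard owns the restore responsibility.
  ScopedCudaDevice(const ScopedCudaDevice&) = delete;
  ScopedCudaDevice& operator=(const ScopedCudaDevice&) = delete;

 private:
  int prev_device_ = 0;
};
```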
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index a054169330b..b8e3dc8e21b 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -19,9 +19,7 @@
 #include 
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
@@ -659,6 +657,4 @@ AOTITorchError aoti_torch__reinterpret_tensor(
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index bcec6621285..7a8d4c3609b 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -12,9 +12,7 @@
 #include 
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 using executorch::backends::aoti::AOTITorchError;
 using executorch::backends::aoti::Tensor;
@@ -145,6 +143,4 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
 void clear_all_tensors();
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tensor_attribute.cpp b/backends/cuda/runtime/shims/tensor_attribute.cpp
index 5b640b7a9e8..1a14c79f9f2 100644
--- a/backends/cuda/runtime/shims/tensor_attribute.cpp
+++ b/backends/cuda/runtime/shims/tensor_attribute.cpp
@@ -8,9 +8,7 @@
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 extern "C" {
 
@@ -31,6 +29,4 @@ int32_t aoti_torch_device_type_cuda() {
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tensor_attribute.h b/backends/cuda/runtime/shims/tensor_attribute.h
index e99958b4f0c..15a4e397d24 100644
--- a/backends/cuda/runtime/shims/tensor_attribute.h
+++ b/backends/cuda/runtime/shims/tensor_attribute.h
@@ -12,9 +12,7 @@
 #include 
 #include 
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 // Common using declarations for ExecutorTorch types
 using executorch::runtime::Error;
@@ -35,6 +33,4 @@ int32_t aoti_torch_device_type_cuda();
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h
index 02c3abfc83f..2d805724090 100644
--- a/backends/cuda/runtime/utils.h
+++ b/backends/cuda/runtime/utils.h
@@ -34,9 +34,7 @@
 #define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
   ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 // Enum for supported data types in et-cuda backend
 enum class SupportedDTypes : int32_t {
@@ -125,6 +123,4 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
   }
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
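Note on the namespace changes that make up the bulk of PATCH 2/2: they mechanically collapse three nested namespace blocks into a single C++17 nested namespace definition. The two spellings are exactly equivalent, as the standalone sketch below shows (`f` and `g` are placeholder declarations):

```cpp
// Pre-C++17 spelling (removed throughout this patch):
namespace executorch {
namespace backends {
namespace cuda {
void f();
} // namespace cuda
} // namespace backends
} // namespace executorch

// C++17 nested namespace definition (added throughout): identical meaning,
// one line to open and one comment-matched line to close.
namespace executorch::backends::cuda {
void g();
} // namespace executorch::backends::cuda
```

The one non-mechanical change is in cuda_backend.cpp, where the registration block now sits in its own `namespace executorch::backends`; the `static executorch::runtime::Error success_with_compiler = register_backend(backend);` initializer still runs during static initialization when the library is loaded, which is what registers "CudaBackend" without any explicit init call.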