From 13bdf55ec3571f09f816389d3b321e856a17cfb9 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Wed, 8 Oct 2025 16:04:31 -0700
Subject: [PATCH 1/2] refactor cuda_backend.cpp

Pull Request resolved: https://github.com/pytorch/executorch/pull/14904

This diff is a comprehensive refactor of cuda_backend.cpp. Two main changes:
1. Reuse the standard ExecuTorch macros (ET_CHECK_OR_RETURN_ERROR and
   friends) to replace the existing if/else + ET_LOG branches.
2. Introduce a LOAD_SYMBOL macro to centralize the symbol-loading logic.

ghstack-source-id: 314984328
@exported-using-ghexport

Differential Revision: [D84135844](https://our.internmc.facebook.com/intern/diff/D84135844/)
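
For reference, the shape of the change, condensed from the diff below (the
AOTInductorModelContainer*Func pointer types already exist in the AOTI shim
headers):

Before, repeated once per symbol:

    AOTInductorModelContainerRun =
        reinterpret_cast<AOTInductorModelContainerRunFunc>(
            dlsym(so_handle, "AOTInductorModelContainerRun"));
    if (AOTInductorModelContainerRun == nullptr) {
      ET_LOG(Error, "Failed to load AOTInductorModelContainerRun");
      return Error::AccessFailed;
    }

After, one line per symbol (the macro stringizes `name` for both the dlsym()
lookup and the error message, and returns Error::AccessFailed on failure):

    LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);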
---
 backends/aoti/targets.bzl              |   2 +-
 backends/cuda/runtime/TARGETS          |  22 ++++
 backends/cuda/runtime/cuda_backend.cpp | 174 ++++++++++---------------
 3 files changed, 91 insertions(+), 107 deletions(-)

diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl
index 79f082e5a89..8bf44573bb3 100644
--- a/backends/aoti/targets.bzl
+++ b/backends/aoti/targets.bzl
@@ -51,7 +51,7 @@ def define_common_targets():
         link_whole = True,
         supports_python_dlopen = True,
         visibility = ["@EXECUTORCH_CLIENTS"],
-        deps = [
+        exported_deps = [
             ":common_shims",
             ":model_container",
         ],
diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS
index 0386b5a008d..54412269287 100644
--- a/backends/cuda/runtime/TARGETS
+++ b/backends/cuda/runtime/TARGETS
@@ -34,3 +34,25 @@ runtime.cxx_library(
         ("cuda", None, "cuda-lazy"),
     ],
 )
+
+runtime.cxx_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.cpp",
+    ],
+    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+    link_whole = True,
+    supports_python_dlopen = True,
+    # Constructor needed for backend registration.
+    compiler_flags = ["-Wno-global-constructors"],
+    visibility = ["@EXECUTORCH_CLIENTS"],
+    deps = [
+        ":runtime_shims",
+        "//executorch/backends/aoti:aoti_common",
+        "//executorch/runtime/backend:interface",
+        "//executorch/runtime/core/exec_aten/util:tensor_util",
+    ],
+    external_deps = [
+        ("cuda", None, "cuda-lazy"),
+    ],
+)
diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index 5f113b1ce68..653687fccb7 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -30,6 +30,13 @@ namespace executorch {
 namespace backends {
 namespace cuda {

+#define LOAD_SYMBOL(name, handle)                                  \
+  do {                                                             \
+    name = reinterpret_cast<name##Func>(dlsym(handle, #name));     \
+    ET_CHECK_OR_RETURN_ERROR(                                      \
+        name != nullptr, AccessFailed, "Failed to load " #name);   \
+  } while (0)
+
 using namespace std;
 using namespace aoti;

@@ -53,45 +60,11 @@ class ET_EXPERIMENTAL CudaBackend final
     : public ::executorch::runtime::BackendInterface {
  private:
   Error register_shared_library_functions(void* so_handle) const {
-    AOTInductorModelContainerCreateWithDevice =
-        reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
-    if (AOTInductorModelContainerCreateWithDevice == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerCreateWithDevice");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerDelete =
-        reinterpret_cast<AOTInductorModelContainerDeleteFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerDelete"));
-    if (AOTInductorModelContainerDelete == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerDelete");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerGetNumInputs =
-        reinterpret_cast<AOTInductorModelContainerGetNumInputsFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerGetNumInputs"));
-    if (AOTInductorModelContainerGetNumInputs == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumInputs");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerGetNumOutputs =
-        reinterpret_cast<AOTInductorModelContainerGetNumOutputsFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerGetNumOutputs"));
-    if (AOTInductorModelContainerGetNumOutputs == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumOutputs");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerRun =
-        reinterpret_cast<AOTInductorModelContainerRunFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerRun"));
-    if (AOTInductorModelContainerRun == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerRun");
-      return Error::AccessFailed;
-    }
+    LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);

     return Error::Ok;
   }
@@ -122,14 +95,13 @@ class ET_EXPERIMENTAL CudaBackend final
     const NamedDataMap* named_data_map = context.get_named_data_map();
     auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
-    if (!aoti_cuda_buffer.ok()) {
-      ET_LOG(
-          Error,
-          "Failed to get data for key %s: 0x%x",
-          so_blob_key.c_str(),
-          aoti_cuda_buffer.error());
-      return aoti_cuda_buffer.error();
-    }
+    ET_CHECK_OR_RETURN_ERROR(
+        aoti_cuda_buffer.ok(),
+        Internal,
+        "Failed to get data for key %s: 0x%x",
+        so_blob_key.c_str(),
+        static_cast<uint32_t>(aoti_cuda_buffer.error()));
+
     // Generate dynamic temporary file path
     filesystem::path temp_dir = filesystem::temp_directory_path();
     filesystem::path so_path =
@@ -144,39 +116,35 @@ class ET_EXPERIMENTAL CudaBackend final
         "Writing %zu bytes to %s",
         aoti_cuda_buffer->size(),
         so_path.c_str());
+
     outfile.write(
         static_cast<const char*>(aoti_cuda_buffer->data()),
         aoti_cuda_buffer->size());
-    if (!outfile) {
-      ET_LOG(Error, "Failed to write to file %s", so_path.c_str());
-      return Error::AccessFailed;
-    }
+    ET_CHECK_OR_RETURN_ERROR(
+        outfile, AccessFailed, "Failed to write to file %s", so_path.c_str());
+
     // Finish writing the file to disk
     outfile.close();

     // Load the ELF using dlopen
     void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
-    if (so_handle == nullptr) {
-      ET_LOG(Error, "Failed to load shared library: %s", dlerror());
-      return Error::AccessFailed;
-    }
+    ET_CHECK_OR_RETURN_ERROR(
+        so_handle != nullptr,
+        AccessFailed,
+        "Failed to load shared library: %s",
+        dlerror());

     processed->Free();

     // Register all shared library functions
-    Error reg_err = register_shared_library_functions(so_handle);
-    if (reg_err != Error::Ok) {
-      return reg_err;
-    }
+    ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle));

     AOTInductorModelContainerHandle container_handle = nullptr;
-    AOTIRuntimeError err = AOTInductorModelContainerCreateWithDevice(
-        &container_handle, 1, "cuda", nullptr);
-    if (err != Error::Ok) {
-      return err;
-    }
+    ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice(
+        &container_handle, 1, "cuda", nullptr));
+
     ET_LOG(Info, "container_handle = %p", container_handle);

     AOTIDelegateHandle* handle = new AOTIDelegateHandle();
@@ -206,15 +174,13 @@ class ET_EXPERIMENTAL CudaBackend final
     AOTInductorModelContainerGetNumOutputs(
         handle->container_handle, &n_outputs);

-    if (n_inputs + n_outputs != args.size()) {
-      ET_LOG(
-          Error,
-          "number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
runner's %zd. Exit.", - n_inputs, - n_outputs, - args.size()); - return Error::InvalidArgument; - } + ET_CHECK_OR_RETURN_ERROR( + n_inputs + n_outputs == args.size(), + InvalidArgument, + "number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.", + n_inputs, + n_outputs, + args.size()) // NOTE: ExecuTorch tensors are always on CPU/host memory // We need to create GPU copies for CUDA kernel execution @@ -244,19 +210,20 @@ class ET_EXPERIMENTAL CudaBackend final 0, // device_index = 0 &gpu_input_handle); - if (create_err != Error::Ok) { - ET_LOG(Error, "Failed to create GPU tensor for input %d", i); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + create_err == Error::Ok, + Internal, + "Failed to create GPU tensor for input %d", + i); gpu_inputs[i] = gpu_input_handle; // Copy data from CPU to GPU - Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0); - if (copy_err != Error::Ok) { - ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok, + Internal, + "Failed to copy input %d from CPU to GPU", + i); } ET_LOG(Info, "Inputs copied to GPU"); // Process output tensors: create GPU counterparts for ExecuTorch CPU @@ -280,10 +247,11 @@ class ET_EXPERIMENTAL CudaBackend final 0, // device_index = 0 &gpu_output_handle); - if (create_err != Error::Ok) { - ET_LOG(Error, "Failed to create GPU tensor for output %d", i); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + create_err == Error::Ok, + Internal, + "Failed to create GPU tensor for output %d", + i); gpu_outputs[i] = gpu_output_handle; } @@ -298,13 +266,11 @@ class ET_EXPERIMENTAL CudaBackend final handle->cuda_stream, // Pass the actual CUDA stream nullptr); // proxy_executor_handle can remain nullptr - if (error != Error::Ok) { - ET_LOG( - Error, - "AOTInductorModelContainerRun failed with error code %d", - error); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + error == Error::Ok, + Internal, + "AOTInductorModelContainerRun failed with error code %d", + error); // Copy GPU output results back to CPU output tensors for (int i = 0; i < n_outputs; i++) { @@ -356,12 +322,10 @@ class ET_EXPERIMENTAL CudaBackend final if (handle->container_handle != nullptr) { AOTIRuntimeError delete_result = AOTInductorModelContainerDelete(handle->container_handle); - if (delete_result != Error::Ok) { - ET_LOG( - Error, - "AOTInductorModelContainerDelete failed with error code %d", - delete_result); - } + ET_CHECK_OR_LOG_ERROR( + delete_result == Error::Ok, + "Failed to delete AOTInductorModelContainer with error code %d", + delete_result); handle->container_handle = nullptr; } @@ -374,13 +338,11 @@ class ET_EXPERIMENTAL CudaBackend final if (!handle->so_path.empty()) { std::error_code remove_error; std::filesystem::remove(handle->so_path, remove_error); - if (remove_error) { - ET_LOG( - Error, - "Failed to remove temporary shared library %s: %s", - handle->so_path.c_str(), - remove_error.message().c_str()); - } + ET_CHECK_OR_LOG_ERROR( + !remove_error, + "Failed to remove temporary shared library %s: %s", + handle->so_path.c_str(), + remove_error.message().c_str()); } delete handle; From a97aac33a7801404a7cb0acaf1db885db7ca5b0e Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 9 Oct 2025 01:17:29 -0400 Subject: [PATCH 2/2] update cuda delegate resource free pipeline for safety and segfault-free (#14927) This PR was created by the merge bot to 
help merge the original PR into the main branch. ghstack PR number: https://github.com/pytorch/executorch/pull/14905 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/50/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/50/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/49/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/50/orig Differential Revision: [D84135792](https://our.internmc.facebook.com/intern/diff/D84135792/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia --- backends/aoti/CMakeLists.txt | 2 +- backends/aoti/common_shims.cpp | 9 +++- backends/aoti/common_shims.h | 2 + backends/cuda/CMakeLists.txt | 5 +-- backends/cuda/runtime/cuda_backend.cpp | 41 ++++++------------- backends/cuda/runtime/guard.cpp | 8 +--- backends/cuda/runtime/guard.h | 8 +--- backends/cuda/runtime/shims/cuda_guard.cpp | 8 +--- backends/cuda/runtime/shims/cuda_guard.h | 8 +--- backends/cuda/runtime/shims/memory.cpp | 23 ++++++----- backends/cuda/runtime/shims/memory.h | 8 +--- .../cuda/runtime/shims/tensor_attribute.cpp | 8 +--- .../cuda/runtime/shims/tensor_attribute.h | 8 +--- backends/cuda/runtime/utils.h | 8 +--- 14 files changed, 53 insertions(+), 93 deletions(-) diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt index ce364f2c4b0..845144af50f 100644 --- a/backends/aoti/CMakeLists.txt +++ b/backends/aoti/CMakeLists.txt @@ -40,7 +40,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC) # Ensure symbols are exported properly target_link_options(aoti_common PUBLIC -Wl,--export-dynamic) -# Link against PyTorch libraries and standard libraries +# Link against ExecuTorch libraries and standard libraries target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS}) executorch_target_link_options_shared_lib(aoti_common) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index 2f9b36e3c4f..abc83779443 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() { } // Dtype constants - these return the PyTorch dtype codes -// Currently only float32 is supported, but using robust enum-based approach int32_t aoti_torch_dtype_float32() { return 6; // PyTorch's float32 dtype code } +int32_t aoti_torch_dtype_bfloat16() { + return 15; // PyTorch's bfloat16 dtype code +} + +int32_t aoti_torch_dtype_int64() { + return 4; // PyTorch's int64 dtype code +} + // Cleanup functions void cleanup_tensor_metadata() { internal::tensor_to_sizes.clear(); diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index ffcbaa11a08..5f54cd1c878 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim); int32_t aoti_torch_device_type_cpu(); int32_t aoti_torch_layout_strided(); int32_t aoti_torch_dtype_float32(); +int32_t aoti_torch_dtype_bfloat16(); +int32_t aoti_torch_dtype_int64(); // Autograd mode functions int32_t aoti_torch_grad_mode_is_enabled(); diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index dc5b1b786f8..575f676e4cc 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -55,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic) # Link against CUDA::cudart, common 
AOTI library, and PyTorch CUDA libraries target_link_libraries( - aoti_cuda - PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS} - # Link PyTorch libraries for AOTI CUDA functions - ${TORCH_LIBRARIES} + aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS} ) # If you need other CUDA libraries, link them similarly: # target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 653687fccb7..58ab54e1aac 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -26,9 +26,7 @@ #include #include -namespace executorch { -namespace backends { -namespace cuda { +namespace executorch::backends::cuda { #define LOAD_SYMBOL(name, handle) \ do { \ @@ -286,18 +284,6 @@ class ET_EXPERIMENTAL CudaBackend final i); } - // Clean up GPU tensors that we created (ExecuTorch tensors are always - // CPU, so all GPU tensors are our copies) - for (int i = 0; i < n_inputs; i++) { - // All GPU input tensors were created by us, delete them - aoti_torch_delete_tensor_object(gpu_inputs[i]); - } - - for (int i = 0; i < n_outputs; i++) { - // All GPU output tensors were created by us, delete them - aoti_torch_delete_tensor_object(gpu_outputs[i]); - } - return Error::Ok; } @@ -318,16 +304,13 @@ class ET_EXPERIMENTAL CudaBackend final handle->cuda_stream = nullptr; } - // Delete the container BEFORE closing the shared library - if (handle->container_handle != nullptr) { - AOTIRuntimeError delete_result = - AOTInductorModelContainerDelete(handle->container_handle); - ET_CHECK_OR_LOG_ERROR( - delete_result == Error::Ok, - "Failed to delete AOTInductorModelContainer with error code %d", - delete_result); - handle->container_handle = nullptr; - } + // NOTE: AOTInductorModelContainerDelete does not work correctly with + // multiple .so files. Deleting one container frees shared resources, + // which causes segmentation faults when attempting to delete other + // containers. As a workaround, we skip explicit container deletion + // and defer cleanup to the OS. + // TODO(gasoonjia): Find a proper solution for safe container deletion. 
+    // AOTInductorModelContainerDelete(handle->container_handle);

     // Now close the shared library
     if (handle->so_handle != nullptr) {
@@ -346,17 +329,17 @@
     }

     delete handle;
+    clear_all_tensors();
   }
 };

-} // namespace cuda
+} // namespace executorch::backends::cuda

+namespace executorch::backends {
 namespace {
 auto cls = cuda::CudaBackend();
 executorch::runtime::Backend backend{"CudaBackend", &cls};
 static executorch::runtime::Error success_with_compiler =
     register_backend(backend);
 } // namespace
-
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends
diff --git a/backends/cuda/runtime/guard.cpp b/backends/cuda/runtime/guard.cpp
index 885efc7670d..674cc6387b3 100644
--- a/backends/cuda/runtime/guard.cpp
+++ b/backends/cuda/runtime/guard.cpp
@@ -9,9 +9,7 @@
 #include
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 namespace {
 // Thread-local stream storage (private to this file)
@@ -146,6 +144,4 @@ Result<CUDAStreamGuard> CUDAStreamGuard::create(
   return stream_guard;
 }

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/guard.h b/backends/cuda/runtime/guard.h
index 4e5a18a4c0f..3f187000f90 100644
--- a/backends/cuda/runtime/guard.h
+++ b/backends/cuda/runtime/guard.h
@@ -14,9 +14,7 @@
 #include
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 using executorch::runtime::Error;
 using executorch::runtime::Result;
@@ -190,6 +188,4 @@ class CUDAStreamGuard {
   DeviceIndex device_index_;
 };

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/cuda_guard.cpp b/backends/cuda/runtime/shims/cuda_guard.cpp
index 5740d0bf654..bb07acc7ffa 100644
--- a/backends/cuda/runtime/shims/cuda_guard.cpp
+++ b/backends/cuda/runtime/shims/cuda_guard.cpp
@@ -8,9 +8,7 @@
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 extern "C" {
@@ -104,6 +102,4 @@ AOTITorchError aoti_torch_get_current_cuda_stream(

 } // extern "C"

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/cuda_guard.h b/backends/cuda/runtime/shims/cuda_guard.h
index 6da869064a7..f930f3df643 100644
--- a/backends/cuda/runtime/shims/cuda_guard.h
+++ b/backends/cuda/runtime/shims/cuda_guard.h
@@ -13,9 +13,7 @@
 #include
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 using executorch::backends::aoti::AOTITorchError;
@@ -99,6 +97,4 @@ AOTITorchError aoti_torch_get_current_cuda_stream(

 } // extern "C"

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index cbaca68576e..b8e3dc8e21b 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -19,9 +19,7 @@
 #include
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
@@ -271,14 +269,21 @@ void clear_all_tensors() {
   // Use aoti_torch_delete_tensor_object to properly delete each tensor
   // Note: We need to collect tensor pointers first since deletion modifies the
   // set
-  auto old_tensors =
-      std::move(tensors); // tensors is now empty and no need to copy
-  for (const auto& tensor_shared : old_tensors) {
-    aoti_torch_delete_tensor_object(tensor_shared.get());
+  std::vector<Tensor*> tensor_ptrs;
+  tensor_ptrs.reserve(tensors.size());
+  for (const auto& tensor_shared : tensors) {
+    tensor_ptrs.push_back(tensor_shared.get());
+  }
+
+  // Now delete each tensor - this will modify the global tensors set
+  for (Tensor* tensor_ptr : tensor_ptrs) {
+    aoti_torch_delete_tensor_object(tensor_ptr);
   }

   // tensors set should now be empty, but ensure it's cleared
   tensors.clear();
+
+  ET_LOG(Info, "Cleared all tensors");
 }

 AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
@@ -652,6 +657,4 @@ AOTITorchError aoti_torch__reinterpret_tensor(

 } // extern "C"

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index bcec6621285..7a8d4c3609b 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -12,9 +12,7 @@
 #include
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 using executorch::backends::aoti::AOTITorchError;
 using executorch::backends::aoti::Tensor;
@@ -145,6 +143,4 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
 void clear_all_tensors();
 } // extern "C"

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tensor_attribute.cpp b/backends/cuda/runtime/shims/tensor_attribute.cpp
index 5b640b7a9e8..1a14c79f9f2 100644
--- a/backends/cuda/runtime/shims/tensor_attribute.cpp
+++ b/backends/cuda/runtime/shims/tensor_attribute.cpp
@@ -8,9 +8,7 @@
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 extern "C" {
@@ -31,6 +29,4 @@ int32_t aoti_torch_device_type_cuda() {

 } // extern "C"

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tensor_attribute.h b/backends/cuda/runtime/shims/tensor_attribute.h
index e99958b4f0c..15a4e397d24 100644
--- a/backends/cuda/runtime/shims/tensor_attribute.h
+++ b/backends/cuda/runtime/shims/tensor_attribute.h
@@ -12,9 +12,7 @@
 #include
 #include

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 // Common using declarations for ExecutorTorch types
 using executorch::runtime::Error;
@@ -35,6 +33,4 @@ int32_t aoti_torch_device_type_cuda();

 } // extern "C"

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h
index 02c3abfc83f..2d805724090 100644
--- a/backends/cuda/runtime/utils.h
+++ b/backends/cuda/runtime/utils.h
@@ -34,9 +34,7 @@
 #define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
   ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())

-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {

 // Enum for supported data types in et-cuda backend
 enum class SupportedDTypes : int32_t {
@@ -125,6 +123,4 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
 }

 } // extern "C"

-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
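
A note on the clear_all_tensors() change in memory.cpp above:
aoti_torch_delete_tensor_object() erases entries from the same global
`tensors` set, so deleting while iterating that set risks using an
invalidated iterator; the patch therefore snapshots the raw pointers first
and only then deletes. A minimal standalone sketch of the pattern
(hypothetical stand-in types and names, not the ExecuTorch API):

    #include <memory>
    #include <set>
    #include <vector>

    struct Tensor {}; // stand-in for the real tensor type

    // Global registry of live tensors, mirroring the shim's `tensors` set.
    std::set<std::shared_ptr<Tensor>> tensors;

    // Erases (and thereby frees) the entry owning `t`, analogous to what
    // aoti_torch_delete_tensor_object() does in the shim.
    void delete_tensor(Tensor* t) {
      for (auto it = tensors.begin(); it != tensors.end(); ++it) {
        if (it->get() == t) {
          tensors.erase(it);
          return;
        }
      }
    }

    void clear_all() {
      // Snapshot raw pointers first: delete_tensor() mutates `tensors`, so
      // erasing while iterating the set directly would invalidate the
      // loop's iterator.
      std::vector<Tensor*> ptrs;
      ptrs.reserve(tensors.size());
      for (const auto& sp : tensors) {
        ptrs.push_back(sp.get());
      }
      for (Tensor* p : ptrs) {
        delete_tensor(p);
      }
      tensors.clear(); // should already be empty; clear defensively
    }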