diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt
index 8d49bcf1f96..845144af50f 100644
--- a/backends/aoti/CMakeLists.txt
+++ b/backends/aoti/CMakeLists.txt
@@ -40,13 +40,8 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
 # Ensure symbols are exported properly
 target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
 
-# Link against PyTorch libraries and standard libraries
-target_link_libraries(
-  aoti_common
-  PUBLIC extension_tensor ${CMAKE_DL_LIBS}
-         # Link PyTorch libraries for AOTI functions
-         ${TORCH_LIBRARIES}
-)
+# Link against ExecuTorch libraries and standard libraries
+target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
 executorch_target_link_options_shared_lib(aoti_common)
 
 install(
diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h
index 844bd2d5a77..9b185327172 100644
--- a/backends/aoti/aoti_model_container.h
+++ b/backends/aoti/aoti_model_container.h
@@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
   void* so_handle;
   std::string so_path;
   AOTInductorModelContainerHandle container_handle;
+  void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
+                     // dependency
 };
 
 } // namespace aoti
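Storing the stream as a void* keeps aoti_model_container.h free of any CUDA header dependency; only CUDA-aware translation units cast the pointer back to its real type. A minimal sketch of the round-trip, mirroring what cuda_backend.cpp below does (illustrative, not part of the patch):

    #include <cuda_runtime.h>

    // Write side: create the stream, then erase its type for storage.
    cudaStream_t stream = nullptr;
    cudaStreamCreate(&stream);
    handle->cuda_stream = static_cast<void*>(stream);

    // Read side: recover the typed handle before any CUDA call.
    cudaStream_t recovered = static_cast<cudaStream_t>(handle->cuda_stream);
    cudaStreamSynchronize(recovered);
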
diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp
index 2f9b36e3c4f..abc83779443 100644
--- a/backends/aoti/common_shims.cpp
+++ b/backends/aoti/common_shims.cpp
@@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() {
 }
 
 // Dtype constants - these return the PyTorch dtype codes
-// Currently only float32 is supported, but using robust enum-based approach
 int32_t aoti_torch_dtype_float32() {
   return 6; // PyTorch's float32 dtype code
 }
 
+int32_t aoti_torch_dtype_bfloat16() {
+  return 15; // PyTorch's bfloat16 dtype code
+}
+
+int32_t aoti_torch_dtype_int64() {
+  return 4; // PyTorch's int64 dtype code
+}
+
 // Cleanup functions
 void cleanup_tensor_metadata() {
   internal::tensor_to_sizes.clear();
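These integer codes follow PyTorch's c10::ScalarType enum ordering, which ExecuTorch mirrors in executorch::aten::ScalarType (Long = 4, Float = 6, BFloat16 = 15). A sketch of how a caller could map the codes back to typed values; dtype_code_to_scalar_type is illustrative and not part of the patch:

    #include <cstdint>
    #include <executorch/runtime/core/exec_aten/exec_aten.h>

    // Maps the AOTI dtype codes above onto ExecuTorch scalar types.
    inline executorch::aten::ScalarType dtype_code_to_scalar_type(int32_t code) {
      switch (code) {
        case 4:  return executorch::aten::ScalarType::Long;     // int64
        case 6:  return executorch::aten::ScalarType::Float;    // float32
        case 15: return executorch::aten::ScalarType::BFloat16; // bfloat16
        default: return executorch::aten::ScalarType::Undefined;
      }
    }
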
diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h
index ffcbaa11a08..5f54cd1c878 100644
--- a/backends/aoti/common_shims.h
+++ b/backends/aoti/common_shims.h
@@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
 int32_t aoti_torch_device_type_cpu();
 int32_t aoti_torch_layout_strided();
 int32_t aoti_torch_dtype_float32();
+int32_t aoti_torch_dtype_bfloat16();
+int32_t aoti_torch_dtype_int64();
 
 // Autograd mode functions
 int32_t aoti_torch_grad_mode_is_enabled();
diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl
index 79f082e5a89..8bf44573bb3 100644
--- a/backends/aoti/targets.bzl
+++ b/backends/aoti/targets.bzl
@@ -51,7 +51,7 @@ def define_common_targets():
         link_whole = True,
         supports_python_dlopen = True,
         visibility = ["@EXECUTORCH_CLIENTS"],
-        deps = [
+        exported_deps = [
            ":common_shims",
            ":model_container",
        ],
diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt
index acbb7adc87f..575f676e4cc 100644
--- a/backends/cuda/CMakeLists.txt
+++ b/backends/cuda/CMakeLists.txt
@@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 find_package_torch()
 
 # CUDA-specific AOTI functionality
-set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
-    runtime/shims/tensor_attribute.cpp runtime/guard.cpp
+set(_aoti_cuda_sources
+    runtime/cuda_backend.cpp runtime/shims/memory.cpp
+    runtime/shims/tensor_attribute.cpp runtime/guard.cpp
+    runtime/shims/cuda_guard.cpp
 )
 add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
 target_include_directories(
@@ -53,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic)
 
 # Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
 target_link_libraries(
-  aoti_cuda
-  PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
-         # Link PyTorch libraries for AOTI CUDA functions
-         ${TORCH_LIBRARIES}
+  aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
 )
 # If you need other CUDA libraries, link them similarly:
 # target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)
diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS
index c4b778eccc5..54412269287 100644
--- a/backends/cuda/runtime/TARGETS
+++ b/backends/cuda/runtime/TARGETS
@@ -6,11 +6,13 @@ runtime.cxx_library(
     name = "runtime_shims",
     srcs = [
         "guard.cpp",
+        "shims/cuda_guard.cpp",
         "shims/memory.cpp",
         "shims/tensor_attribute.cpp",
     ],
     headers = [
         "guard.h",
+        "shims/cuda_guard.h",
         "shims/memory.h",
         "shims/tensor_attribute.h",
         "utils.h",
@@ -32,3 +34,25 @@ runtime.cxx_library(
         ("cuda", None, "cuda-lazy"),
     ],
 )
+
+runtime.cxx_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.cpp",
+    ],
+    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+    link_whole = True,
+    supports_python_dlopen = True,
+    # Constructor needed for backend registration.
+    compiler_flags = ["-Wno-global-constructors"],
+    visibility = ["@EXECUTORCH_CLIENTS"],
+    deps = [
+        ":runtime_shims",
+        "//executorch/backends/aoti:aoti_common",
+        "//executorch/runtime/backend:interface",
+        "//executorch/runtime/core/exec_aten/util:tensor_util",
+    ],
+    external_deps = [
+        ("cuda", None, "cuda-lazy"),
+    ],
+)
diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index 08031ce6a26..58ab54e1aac 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <cuda_runtime.h>
 #include
 #include
 #include
@@ -16,7 +17,6 @@
 #include
 #include
-#include
 #include
 #include
 
@@ -24,10 +24,16 @@
 #include
 #include
 #include
+#include <executorch/backends/cuda/runtime/utils.h>
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
+
+#define LOAD_SYMBOL(name, handle)                                  \
+  do {                                                             \
+    name = reinterpret_cast<decltype(name)>(dlsym(handle, #name)); \
+    ET_CHECK_OR_RETURN_ERROR(                                      \
+        name != nullptr, AccessFailed, "Failed to load " #name);   \
+  } while (0)
 
 using namespace std;
 using namespace aoti;
@@ -52,45 +58,11 @@ class ET_EXPERIMENTAL CudaBackend final
     : public ::executorch::runtime::BackendInterface {
  private:
   Error register_shared_library_functions(void* so_handle) const {
-    AOTInductorModelContainerCreateWithDevice =
-        reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
-    if (AOTInductorModelContainerCreateWithDevice == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerCreateWithDevice");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerDelete =
-        reinterpret_cast<AOTInductorModelContainerDeleteFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerDelete"));
-    if (AOTInductorModelContainerDelete == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerDelete");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerGetNumInputs =
-        reinterpret_cast<AOTInductorModelContainerGetNumInputsFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerGetNumInputs"));
-    if (AOTInductorModelContainerGetNumInputs == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumInputs");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerGetNumOutputs =
-        reinterpret_cast<AOTInductorModelContainerGetNumOutputsFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerGetNumOutputs"));
-    if (AOTInductorModelContainerGetNumOutputs == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumOutputs");
-      return Error::AccessFailed;
-    }
-
-    AOTInductorModelContainerRun =
-        reinterpret_cast<AOTInductorModelContainerRunFunc>(
-            dlsym(so_handle, "AOTInductorModelContainerRun"));
-    if (AOTInductorModelContainerRun == nullptr) {
-      ET_LOG(Error, "Failed to load AOTInductorModelContainerRun");
-      return Error::AccessFailed;
-    }
+    LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
+    LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
 
     return Error::Ok;
   }
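LOAD_SYMBOL leans on the function-pointer globals declared in aoti_model_container.h; decltype(name) lets one macro serve every signature. Expanded by hand, the first invocation is roughly equivalent to:

    do {
      AOTInductorModelContainerCreateWithDevice =
          reinterpret_cast<decltype(AOTInductorModelContainerCreateWithDevice)>(
              dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
      ET_CHECK_OR_RETURN_ERROR(
          AOTInductorModelContainerCreateWithDevice != nullptr,
          AccessFailed,
          "Failed to load AOTInductorModelContainerCreateWithDevice");
    } while (0);
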
@@ -121,14 +93,13 @@ class ET_EXPERIMENTAL CudaBackend final
     const NamedDataMap* named_data_map = context.get_named_data_map();
     auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
-    if (!aoti_cuda_buffer.ok()) {
-      ET_LOG(
-          Error,
-          "Failed to get data for key %s: 0x%x",
-          so_blob_key.c_str(),
-          aoti_cuda_buffer.error());
-      return aoti_cuda_buffer.error();
-    }
+    ET_CHECK_OR_RETURN_ERROR(
+        aoti_cuda_buffer.ok(),
+        Internal,
+        "Failed to get data for key %s: 0x%x",
+        so_blob_key.c_str(),
+        static_cast<uint32_t>(aoti_cuda_buffer.error()));
+
     // Generate dynamic temporary file path
     filesystem::path temp_dir = filesystem::temp_directory_path();
     filesystem::path so_path =
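ET_CHECK_OR_RETURN_ERROR(cond, Code, fmt, ...) collapses the log-then-return pattern that the deleted block spelled out; behaviorally it reduces to the sketch below. Note one deliberate change at this call site: the old code propagated aoti_cuda_buffer.error(), while the macro now returns Error::Internal.

    // Behavioral sketch of the macro at this call site (not its literal body).
    if (!aoti_cuda_buffer.ok()) {
      ET_LOG(
          Error,
          "Failed to get data for key %s: 0x%x",
          so_blob_key.c_str(),
          static_cast<uint32_t>(aoti_cuda_buffer.error()));
      return Error::Internal;
    }
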
Exit.", + n_inputs, + n_outputs, + args.size()) // NOTE: ExecuTorch tensors are always on CPU/host memory // We need to create GPU copies for CUDA kernel execution @@ -237,19 +208,20 @@ class ET_EXPERIMENTAL CudaBackend final 0, // device_index = 0 &gpu_input_handle); - if (create_err != Error::Ok) { - ET_LOG(Error, "Failed to create GPU tensor for input %d", i); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + create_err == Error::Ok, + Internal, + "Failed to create GPU tensor for input %d", + i); gpu_inputs[i] = gpu_input_handle; // Copy data from CPU to GPU - Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0); - if (copy_err != Error::Ok) { - ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok, + Internal, + "Failed to copy input %d from CPU to GPU", + i); } ET_LOG(Info, "Inputs copied to GPU"); // Process output tensors: create GPU counterparts for ExecuTorch CPU @@ -273,10 +245,11 @@ class ET_EXPERIMENTAL CudaBackend final 0, // device_index = 0 &gpu_output_handle); - if (create_err != Error::Ok) { - ET_LOG(Error, "Failed to create GPU tensor for output %d", i); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + create_err == Error::Ok, + Internal, + "Failed to create GPU tensor for output %d", + i); gpu_outputs[i] = gpu_output_handle; } @@ -288,16 +261,14 @@ class ET_EXPERIMENTAL CudaBackend final n_inputs, gpu_outputs.data(), // Use GPU output tensors n_outputs, - nullptr, // Pass the actual CUDA stream! + handle->cuda_stream, // Pass the actual CUDA stream nullptr); // proxy_executor_handle can remain nullptr - if (error != Error::Ok) { - ET_LOG( - Error, - "AOTInductorModelContainerRun failed with error code %d", - error); - return Error::Internal; - } + ET_CHECK_OR_RETURN_ERROR( + error == Error::Ok, + Internal, + "AOTInductorModelContainerRun failed with error code %d", + error); // Copy GPU output results back to CPU output tensors for (int i = 0; i < n_outputs; i++) { @@ -313,18 +284,6 @@ class ET_EXPERIMENTAL CudaBackend final i); } - // Clean up GPU tensors that we created (ExecuTorch tensors are always - // CPU, so all GPU tensors are our copies) - for (int i = 0; i < n_inputs; i++) { - // All GPU input tensors were created by us, delete them - aoti_torch_delete_tensor_object(gpu_inputs[i]); - } - - for (int i = 0; i < n_outputs; i++) { - // All GPU output tensors were created by us, delete them - aoti_torch_delete_tensor_object(gpu_outputs[i]); - } - return Error::Ok; } @@ -334,19 +293,25 @@ class ET_EXPERIMENTAL CudaBackend final } AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; - // Delete the container BEFORE closing the shared library - if (handle->container_handle != nullptr) { - AOTIRuntimeError delete_result = - AOTInductorModelContainerDelete(handle->container_handle); - if (delete_result != Error::Ok) { - ET_LOG( - Error, - "AOTInductorModelContainerDelete failed with error code %d", - delete_result); - } - handle->container_handle = nullptr; + // Destroy the CUDA stream if it exists + if (handle->cuda_stream != nullptr) { + cudaStream_t cuda_stream = static_cast(handle->cuda_stream); + cudaError_t stream_err = cudaStreamDestroy(cuda_stream); + ET_CHECK_OR_LOG_ERROR( + stream_err == cudaSuccess, + "Failed to destroy CUDA stream: %s", + cudaGetErrorString(stream_err)); + handle->cuda_stream = nullptr; } + // NOTE: AOTInductorModelContainerDelete does not work correctly with + // multiple .so 
@@ -237,19 +208,20 @@ class ET_EXPERIMENTAL CudaBackend final
           0, // device_index = 0
           &gpu_input_handle);
 
-      if (create_err != Error::Ok) {
-        ET_LOG(Error, "Failed to create GPU tensor for input %d", i);
-        return Error::Internal;
-      }
+      ET_CHECK_OR_RETURN_ERROR(
+          create_err == Error::Ok,
+          Internal,
+          "Failed to create GPU tensor for input %d",
+          i);
 
       gpu_inputs[i] = gpu_input_handle;
 
       // Copy data from CPU to GPU
-      Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0);
-      if (copy_err != Error::Ok) {
-        ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i);
-        return Error::Internal;
-      }
+      ET_CHECK_OR_RETURN_ERROR(
+          aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
+          Internal,
+          "Failed to copy input %d from CPU to GPU",
+          i);
     }
     ET_LOG(Info, "Inputs copied to GPU");
     // Process output tensors: create GPU counterparts for ExecuTorch CPU
@@ -273,10 +245,11 @@ class ET_EXPERIMENTAL CudaBackend final
           0, // device_index = 0
           &gpu_output_handle);
 
-      if (create_err != Error::Ok) {
-        ET_LOG(Error, "Failed to create GPU tensor for output %d", i);
-        return Error::Internal;
-      }
+      ET_CHECK_OR_RETURN_ERROR(
+          create_err == Error::Ok,
+          Internal,
+          "Failed to create GPU tensor for output %d",
+          i);
 
       gpu_outputs[i] = gpu_output_handle;
     }
@@ -288,16 +261,14 @@ class ET_EXPERIMENTAL CudaBackend final
         n_inputs,
         gpu_outputs.data(), // Use GPU output tensors
         n_outputs,
-        nullptr, // Pass the actual CUDA stream!
+        handle->cuda_stream, // Pass the actual CUDA stream
         nullptr); // proxy_executor_handle can remain nullptr
 
-    if (error != Error::Ok) {
-      ET_LOG(
-          Error,
-          "AOTInductorModelContainerRun failed with error code %d",
-          error);
-      return Error::Internal;
-    }
+    ET_CHECK_OR_RETURN_ERROR(
+        error == Error::Ok,
+        Internal,
+        "AOTInductorModelContainerRun failed with error code %d",
+        error);
 
     // Copy GPU output results back to CPU output tensors
     for (int i = 0; i < n_outputs; i++) {
@@ -313,18 +284,6 @@ class ET_EXPERIMENTAL CudaBackend final
           i);
     }
 
-    // Clean up GPU tensors that we created (ExecuTorch tensors are always
-    // CPU, so all GPU tensors are our copies)
-    for (int i = 0; i < n_inputs; i++) {
-      // All GPU input tensors were created by us, delete them
-      aoti_torch_delete_tensor_object(gpu_inputs[i]);
-    }
-
-    for (int i = 0; i < n_outputs; i++) {
-      // All GPU output tensors were created by us, delete them
-      aoti_torch_delete_tensor_object(gpu_outputs[i]);
-    }
-
     return Error::Ok;
   }
@@ -334,19 +293,25 @@ class ET_EXPERIMENTAL CudaBackend final
     }
     AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
 
-    // Delete the container BEFORE closing the shared library
-    if (handle->container_handle != nullptr) {
-      AOTIRuntimeError delete_result =
-          AOTInductorModelContainerDelete(handle->container_handle);
-      if (delete_result != Error::Ok) {
-        ET_LOG(
-            Error,
-            "AOTInductorModelContainerDelete failed with error code %d",
-            delete_result);
-      }
-      handle->container_handle = nullptr;
+    // Destroy the CUDA stream if it exists
+    if (handle->cuda_stream != nullptr) {
+      cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
+      cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
+      ET_CHECK_OR_LOG_ERROR(
+          stream_err == cudaSuccess,
+          "Failed to destroy CUDA stream: %s",
+          cudaGetErrorString(stream_err));
+      handle->cuda_stream = nullptr;
     }
 
+    // NOTE: AOTInductorModelContainerDelete does not work correctly with
+    // multiple .so files. Deleting one container frees shared resources,
+    // which causes segmentation faults when attempting to delete other
+    // containers. As a workaround, we skip explicit container deletion
+    // and defer cleanup to the OS.
+    // TODO(gasoonjia): Find a proper solution for safe container deletion.
+    // AOTInductorModelContainerDelete(handle->container_handle);
+
     // Now close the shared library
     if (handle->so_handle != nullptr) {
       dlclose(handle->so_handle);
@@ -356,27 +321,25 @@ class ET_EXPERIMENTAL CudaBackend final
     if (!handle->so_path.empty()) {
       std::error_code remove_error;
       std::filesystem::remove(handle->so_path, remove_error);
-      if (remove_error) {
-        ET_LOG(
-            Error,
-            "Failed to remove temporary shared library %s: %s",
-            handle->so_path.c_str(),
-            remove_error.message().c_str());
-      }
+      ET_CHECK_OR_LOG_ERROR(
+          !remove_error,
+          "Failed to remove temporary shared library %s: %s",
+          handle->so_path.c_str(),
+          remove_error.message().c_str());
     }
 
     delete handle;
+    clear_all_tensors();
   }
 };
 
-} // namespace cuda
+} // namespace executorch::backends::cuda
 
+namespace executorch::backends {
 namespace {
 auto cls = cuda::CudaBackend();
 executorch::runtime::Backend backend{"CudaBackend", &cls};
 static executorch::runtime::Error success_with_compiler =
     register_backend(backend);
 } // namespace
-
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends
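Registration relies on the static globals above: constructing executorch::runtime::Backend{"CudaBackend", &cls} at static-init time is what makes the delegate discoverable, and it is why the Buck target sets link_whole = True (otherwise the linker may drop the object file for lack of direct references). A hedged lookup sketch, assuming the runtime's get_backend_class registry accessor:

    #include <executorch/runtime/backend/interface.h>

    // Returns nullptr if the registration global above never ran.
    const bool registered =
        executorch::runtime::get_backend_class("CudaBackend") != nullptr;
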
diff --git a/backends/cuda/runtime/guard.cpp b/backends/cuda/runtime/guard.cpp
index 885efc7670d..674cc6387b3 100644
--- a/backends/cuda/runtime/guard.cpp
+++ b/backends/cuda/runtime/guard.cpp
@@ -9,9 +9,7 @@
 #include
 #include
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 namespace {
 // Thread-local stream storage (private to this file)
@@ -146,6 +144,4 @@ Result<CUDAStreamGuard> CUDAStreamGuard::create(
   return stream_guard;
 }
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/guard.h b/backends/cuda/runtime/guard.h
index 4e5a18a4c0f..3f187000f90 100644
--- a/backends/cuda/runtime/guard.h
+++ b/backends/cuda/runtime/guard.h
@@ -14,9 +14,7 @@
 #include
 #include
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 using executorch::runtime::Error;
 using executorch::runtime::Result;
@@ -190,6 +188,4 @@ class CUDAStreamGuard {
   DeviceIndex device_index_;
 };
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/cuda_guard.cpp b/backends/cuda/runtime/shims/cuda_guard.cpp
new file mode 100644
index 00000000000..bb07acc7ffa
--- /dev/null
+++ b/backends/cuda/runtime/shims/cuda_guard.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>
+
+namespace executorch::backends::cuda {
+
+extern "C" {
+
+AOTITorchError aoti_torch_create_cuda_guard(
+    int32_t device_index,
+    CUDAGuardHandle* ret_guard) {
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_guard != nullptr,
+      InvalidArgument,
+      "aoti_torch_create_cuda_guard failed: ret_guard is null");
+
+  auto result = CUDAGuard::create(device_index);
+  if (!result.ok()) {
+    return result.error();
+  }
+  *ret_guard = new CUDAGuard(std::move(result.get()));
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
+  ET_CHECK_OR_RETURN_ERROR(
+      guard != nullptr,
+      InvalidArgument,
+      "aoti_torch_delete_cuda_guard failed: guard is null");
+
+  delete guard;
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_cuda_guard_set_index(
+    CUDAGuardHandle guard,
+    int32_t device_index) {
+  ET_CHECK_OR_RETURN_ERROR(
+      guard != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_guard_set_index failed: guard is null");
+
+  ET_CHECK_OK_OR_RETURN_ERROR(guard->set_index(device_index));
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_create_cuda_stream_guard(
+    void* stream,
+    int32_t device_index,
+    CUDAStreamGuardHandle* ret_guard) {
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_guard != nullptr,
+      InvalidArgument,
+      "aoti_torch_create_cuda_stream_guard failed: ret_guard is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      stream != nullptr,
+      InvalidArgument,
+      "aoti_torch_create_cuda_stream_guard failed: stream is null");
+
+  auto result =
+      CUDAStreamGuard::create(static_cast<cudaStream_t>(stream), device_index);
+  if (!result.ok()) {
+    return result.error();
+  }
+  *ret_guard = new CUDAStreamGuard(std::move(result.get()));
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_delete_cuda_stream_guard(
+    CUDAStreamGuardHandle guard) {
+  ET_CHECK_OR_RETURN_ERROR(
+      guard != nullptr,
+      InvalidArgument,
+      "aoti_torch_delete_cuda_stream_guard failed: guard is null");
+
+  delete guard;
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_current_cuda_stream(
+    int32_t device_index,
+    void** ret_stream) {
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_stream != nullptr,
+      InvalidArgument,
+      "aoti_torch_get_current_cuda_stream failed: ret_stream is null");
+
+  auto result = getCurrentCUDAStream(device_index);
+  if (!result.ok()) {
+    return result.error();
+  }
+  *ret_stream = static_cast<void*>(result.get());
+  return Error::Ok;
+}
+
+} // extern "C"
+
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/cuda_guard.h b/backends/cuda/runtime/shims/cuda_guard.h
new file mode 100644
index 00000000000..f930f3df643
--- /dev/null
+++ b/backends/cuda/runtime/shims/cuda_guard.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <cuda_runtime.h>
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/cuda/runtime/guard.h>
+
+namespace executorch::backends::cuda {
+
+using executorch::backends::aoti::AOTITorchError;
+
+extern "C" {
+
+// Handle types for CUDA guards
+using CUDAGuardHandle = CUDAGuard*;
+using CUDAStreamGuardHandle = CUDAStreamGuard*;
+
+/**
+ * Creates a CUDA device guard that sets the current device and restores it
+ * upon destruction.
+ *
+ * @param device_index The device index to set as current
+ * @param ret_guard Output parameter for the created guard handle (must not be
+ * null)
+ * @return AOTITorchError error code (Error::Ok on success, or an error code on
+ * failure)
+ */
+AOTITorchError aoti_torch_create_cuda_guard(
+    int32_t device_index,
+    CUDAGuardHandle* ret_guard);
+
+/**
+ * Deletes a CUDA device guard and frees its associated resources.
+ *
+ * @param guard Handle to the guard to be deleted
+ * @return AOTITorchError error code (Error::Ok on success, or an error code on
+ * failure)
+ */
+AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);
+
+/**
+ * Sets the CUDA device to a new index for an existing guard.
+ *
+ * @param guard Handle to the guard
+ * @param device_index The device index to set as current
+ * @return AOTITorchError error code (Error::Ok on success, or an error code on
+ * failure)
+ */
+AOTITorchError aoti_torch_cuda_guard_set_index(
+    CUDAGuardHandle guard,
+    int32_t device_index);
+
+/**
+ * Creates a CUDA stream guard that sets the current device and stream,
+ * restoring both upon destruction.
+ *
+ * @param stream The CUDA stream to set as current
+ * @param device_index The device index for the stream
+ * @param ret_guard Output parameter for the created guard handle (must not be
+ * null)
+ * @return AOTITorchError error code (Error::Ok on success, or an error code on
+ * failure)
+ */
+AOTITorchError aoti_torch_create_cuda_stream_guard(
+    void* stream,
+    int32_t device_index,
+    CUDAStreamGuardHandle* ret_guard);
+
+/**
+ * Deletes a CUDA stream guard and frees its associated resources.
+ *
+ * @param guard Handle to the stream guard to be deleted
+ * @return AOTITorchError error code (Error::Ok on success, or an error code on
+ * failure)
+ */
+AOTITorchError aoti_torch_delete_cuda_stream_guard(
+    CUDAStreamGuardHandle guard);
+
+/**
+ * Gets the current CUDA stream for a specified device.
+ *
+ * @param device_index The device index (-1 to use current device)
+ * @param ret_stream Output parameter for the current stream (must not be null)
+ * @return AOTITorchError error code (Error::Ok on success, or an error code on
+ * failure)
+ */
+AOTITorchError aoti_torch_get_current_cuda_stream(
+    int32_t device_index,
+    void** ret_stream);
+
+} // extern "C"
+
+} // namespace executorch::backends::cuda
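Together these shims expose an extern "C" RAII surface that AOTI-generated code can drive without C++ types. A sketch of the intended call pattern, mirroring the unit tests further down (error handling elided):

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Make stream current on device 0 until the guard is deleted.
    CUDAStreamGuardHandle guard = nullptr;
    aoti_torch_create_cuda_stream_guard(stream, 0, &guard);

    void* current = nullptr;
    aoti_torch_get_current_cuda_stream(0, &current); // now == stream

    // Restores the stream/device that were current at guard creation.
    aoti_torch_delete_cuda_stream_guard(guard);
    cudaStreamDestroy(stream);
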
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index cbaca68576e..b8e3dc8e21b 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -19,9 +19,7 @@
 #include
 #include
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
@@ -271,14 +269,21 @@ void clear_all_tensors() {
   // Use aoti_torch_delete_tensor_object to properly delete each tensor
   // Note: We need to collect tensor pointers first since deletion modifies the
   // set
-  auto old_tensors =
-      std::move(tensors); // tensors is now empty and no need to copy
-  for (const auto& tensor_shared : old_tensors) {
-    aoti_torch_delete_tensor_object(tensor_shared.get());
+  std::vector<Tensor*> tensor_ptrs;
+  tensor_ptrs.reserve(tensors.size());
+  for (const auto& tensor_shared : tensors) {
+    tensor_ptrs.push_back(tensor_shared.get());
+  }
+
+  // Now delete each tensor - this will modify the global tensors set
+  for (Tensor* tensor_ptr : tensor_ptrs) {
+    aoti_torch_delete_tensor_object(tensor_ptr);
   }
 
   // tensors set should now be empty, but ensure it's cleared
   tensors.clear();
+
+  ET_LOG(Info, "Cleared all tensors");
 }
 
 AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
@@ -652,6 +657,4 @@ AOTITorchError aoti_torch__reinterpret_tensor(
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
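The two-phase rewrite in clear_all_tensors() exists because aoti_torch_delete_tensor_object() erases from the same global set: the previous std::move approach emptied the set before deletion, so the per-tensor bookkeeping inside the delete call apparently could no longer find its entries, while the naive single-pass alternative below invalidates the live iterator. Illustrative hazard, not code to copy:

    // Undefined behavior: aoti_torch_delete_tensor_object() erases from
    // the global set while the range-for still holds an iterator into it.
    for (const auto& tensor_shared : tensors) {
      aoti_torch_delete_tensor_object(tensor_shared.get());
    }
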
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index bcec6621285..7a8d4c3609b 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -12,9 +12,7 @@
 #include
 #include
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 using executorch::backends::aoti::AOTITorchError;
 using executorch::backends::aoti::Tensor;
@@ -145,6 +143,4 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
 void clear_all_tensors();
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tensor_attribute.cpp b/backends/cuda/runtime/shims/tensor_attribute.cpp
index 5b640b7a9e8..1a14c79f9f2 100644
--- a/backends/cuda/runtime/shims/tensor_attribute.cpp
+++ b/backends/cuda/runtime/shims/tensor_attribute.cpp
@@ -8,9 +8,7 @@
 
 #include
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 extern "C" {
 
@@ -31,6 +29,4 @@ int32_t aoti_torch_device_type_cuda() {
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tensor_attribute.h b/backends/cuda/runtime/shims/tensor_attribute.h
index e99958b4f0c..15a4e397d24 100644
--- a/backends/cuda/runtime/shims/tensor_attribute.h
+++ b/backends/cuda/runtime/shims/tensor_attribute.h
@@ -12,9 +12,7 @@
 #include
 #include
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 // Common using declarations for ExecutorTorch types
 using executorch::runtime::Error;
@@ -35,6 +33,4 @@ int32_t aoti_torch_device_type_cuda();
 
 } // extern "C"
 
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
index fcb95a0beb7..70f27b86bec 100644
--- a/backends/cuda/runtime/shims/tests/targets.bzl
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -32,3 +32,4 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
     cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
     cuda_shim_cpp_unittest("aoti_torch_copy_")
+    cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp
new file mode 100644
index 00000000000..7527965cdb8
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp
@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+
+using namespace executorch::backends::aoti;
+using namespace executorch::backends::cuda;
+using namespace executorch::runtime;
+
+// TODO(gasoonjia): Multiple device tests were not included due to test
+// environment limitations. Will be added in the future.
+class AOTITorchCUDAGuardTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    et_pal_init();
+
+    int device_count = 0;
+    cudaError_t err = cudaGetDeviceCount(&device_count);
+    if (err != cudaSuccess || device_count == 0) {
+      GTEST_SKIP() << "CUDA not available, skipping CUDA tests";
+    }
+
+    ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess);
+  }
+
+  void TearDown() override {
+    int device_count = 0;
+    if (cudaGetDeviceCount(&device_count) == cudaSuccess && device_count > 0) {
+      // Restore the device that was current before the test ran.
+      ASSERT_EQ(cudaSetDevice(original_device_), cudaSuccess);
+    }
+  }
+
+  int original_device_ = 0;
+};
+
+TEST_F(AOTITorchCUDAGuardTest, CreateAndDeleteCUDAGuard) {
+  CUDAGuardHandle guard = nullptr;
+  AOTITorchError error = aoti_torch_create_cuda_guard(0, &guard);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(guard, nullptr);
+
+  int current_device = -1;
+  ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess);
+  EXPECT_EQ(current_device, 0);
+
+  error = aoti_torch_delete_cuda_guard(guard);
+  EXPECT_EQ(error, Error::Ok);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, CreateCUDAGuardNullReturnPointer) {
+  AOTITorchError error = aoti_torch_create_cuda_guard(0, nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, DeleteCUDAGuardNullHandle) {
+  AOTITorchError error = aoti_torch_delete_cuda_guard(nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, CUDAGuardSetIndexNullHandle) {
+  AOTITorchError error = aoti_torch_cuda_guard_set_index(nullptr, 0);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, CUDAGuardSetIndexInvalidDevice) {
+  CUDAGuardHandle guard = nullptr;
+  AOTITorchError error = aoti_torch_create_cuda_guard(0, &guard);
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(guard, nullptr);
+
+  error = aoti_torch_cuda_guard_set_index(guard, 999);
+  EXPECT_NE(error, Error::Ok);
+
+  error = aoti_torch_delete_cuda_guard(guard);
+  EXPECT_EQ(error, Error::Ok);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, CreateAndDeleteCUDAStreamGuard) {
+  cudaStream_t stream;
+  ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess);
+
+  CUDAStreamGuardHandle guard = nullptr;
+  AOTITorchError error =
+      aoti_torch_create_cuda_stream_guard(stream, 0, &guard);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(guard, nullptr);
+
+  error = aoti_torch_delete_cuda_stream_guard(guard);
+  EXPECT_EQ(error, Error::Ok);
+
+  ASSERT_EQ(cudaStreamDestroy(stream), cudaSuccess);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, CreateCUDAStreamGuardNullReturnPointer) {
+  cudaStream_t stream;
+  ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess);
+
+  AOTITorchError error =
+      aoti_torch_create_cuda_stream_guard(stream, 0, nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  ASSERT_EQ(cudaStreamDestroy(stream), cudaSuccess);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, CreateCUDAStreamGuardNullStream) {
+  CUDAStreamGuardHandle guard = nullptr;
+  AOTITorchError error =
+      aoti_torch_create_cuda_stream_guard(nullptr, 0, &guard);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, DeleteCUDAStreamGuardNullHandle) {
+  AOTITorchError error = aoti_torch_delete_cuda_stream_guard(nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, GetCurrentCUDAStream) {
+  void* ret_stream = nullptr;
+  AOTITorchError error = aoti_torch_get_current_cuda_stream(0, &ret_stream);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(ret_stream, nullptr);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, GetCurrentCUDAStreamNullReturnPointer) {
+  AOTITorchError error = aoti_torch_get_current_cuda_stream(0, nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, StreamGuardWithSameDevice) {
+  ASSERT_EQ(cudaSetDevice(0), cudaSuccess);
+
+  cudaStream_t stream1, stream2;
+  ASSERT_EQ(cudaStreamCreate(&stream1), cudaSuccess);
+  ASSERT_EQ(cudaStreamCreate(&stream2), cudaSuccess);
+
+  CUDAStreamGuardHandle guard1 = nullptr;
+  AOTITorchError error =
+      aoti_torch_create_cuda_stream_guard(stream1, 0, &guard1);
+  EXPECT_EQ(error, Error::Ok);
+
+  void* ret_stream = nullptr;
+  error = aoti_torch_get_current_cuda_stream(0, &ret_stream);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(static_cast<cudaStream_t>(ret_stream), stream1);
+
+  CUDAStreamGuardHandle guard2 = nullptr;
+  error = aoti_torch_create_cuda_stream_guard(stream2, 0, &guard2);
+  EXPECT_EQ(error, Error::Ok);
+
+  ret_stream = nullptr;
+  error = aoti_torch_get_current_cuda_stream(0, &ret_stream);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(static_cast<cudaStream_t>(ret_stream), stream2);
+
+  error = aoti_torch_delete_cuda_stream_guard(guard2);
+  EXPECT_EQ(error, Error::Ok);
+
+  ret_stream = nullptr;
+  error = aoti_torch_get_current_cuda_stream(0, &ret_stream);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(static_cast<cudaStream_t>(ret_stream), stream1);
+
+  error = aoti_torch_delete_cuda_stream_guard(guard1);
+  EXPECT_EQ(error, Error::Ok);
+
+  ASSERT_EQ(cudaStreamDestroy(stream1), cudaSuccess);
+  ASSERT_EQ(cudaStreamDestroy(stream2), cudaSuccess);
+}
+
+TEST_F(AOTITorchCUDAGuardTest, GetCurrentStreamAfterSetStream) {
+  cudaStream_t new_stream;
+  ASSERT_EQ(cudaStreamCreate(&new_stream), cudaSuccess);
+
+  CUDAStreamGuardHandle guard = nullptr;
+  AOTITorchError error =
+      aoti_torch_create_cuda_stream_guard(new_stream, 0, &guard);
+  EXPECT_EQ(error, Error::Ok);
+
+  void* ret_stream = nullptr;
+  error = aoti_torch_get_current_cuda_stream(0, &ret_stream);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(static_cast<cudaStream_t>(ret_stream), new_stream);
+
+  error = aoti_torch_delete_cuda_stream_guard(guard);
+  EXPECT_EQ(error, Error::Ok);
+
+  ASSERT_EQ(cudaStreamDestroy(new_stream), cudaSuccess);
+}
diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h
index 02c3abfc83f..2d805724090 100644
--- a/backends/cuda/runtime/utils.h
+++ b/backends/cuda/runtime/utils.h
@@ -34,9 +34,7 @@
 #define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
   ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
 
-namespace executorch {
-namespace backends {
-namespace cuda {
+namespace executorch::backends::cuda {
 
 // Enum for supported data types in et-cuda backend
 enum class SupportedDTypes : int32_t {
@@ -125,6 +123,4 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
 }
 
 } // extern "C"
-} // namespace cuda
-} // namespace backends
-} // namespace executorch
+} // namespace executorch::backends::cuda
diff --git a/runtime/platform/log.h b/runtime/platform/log.h
index 72ea8528442..7293fa2428d 100644
--- a/runtime/platform/log.h
+++ b/runtime/platform/log.h
@@ -181,6 +181,20 @@ using ::executorch::runtime::LogLevel;
         ##__VA_ARGS__);                                  \
     }                                                    \
   } while (0)
+
+/**
+ * Check a condition and log an error message if the condition is false.
+ *
+ * @param[in] _condition The condition to check.
+ * @param[in] _format Log message format string.
+ */
+#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \
+  do {                                                  \
+    if (!(_condition)) {                                \
+      ET_LOG(Error, _format, ##__VA_ARGS__);            \
+    }                                                   \
+  } while (0)
+
 #else // ET_LOG_ENABLED
 
 /**
@@ -191,4 +205,12 @@ using ::executorch::runtime::LogLevel;
  */
 #define ET_LOG(_level, _format, ...) ((void)0)
 
+/**
+ * Check a condition and log an error message if the condition is false.
+ *
+ * @param[in] _condition The condition to check.
+ * @param[in] _format Log message format string.
+ */
+#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) ((void)0)
+
 #endif // ET_LOG_ENABLED
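Unlike ET_CHECK_OR_RETURN_ERROR, this macro only logs, which suits teardown paths such as CudaBackend::destroy() above, where there is no error to propagate. One caveat follows from the ET_LOG_ENABLED=0 branch compiling to ((void)0): the condition is not evaluated at all in that configuration, so side-effecting calls must stay outside it, as the stream-destruction call site already does:

    // From CudaBackend::destroy(): the CUDA call runs unconditionally;
    // only the check/log disappears when logging is compiled out.
    cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
    ET_CHECK_OR_LOG_ERROR(
        stream_err == cudaSuccess,
        "Failed to destroy CUDA stream: %s",
        cudaGetErrorString(stream_err));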