From e69db18a70f85158625c48daf7780fdb8b06b4a6 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 8 Oct 2025 10:09:45 -0700 Subject: [PATCH] introduce cuda stream into runtime backend This diff introduces CUDA streams into the ExecuTorch runtime backend. The changes include: * Adding CUDA stream support to the `cuda_backend.cpp` file * Including the `cuda_runtime.h` header file in `cuda_backend.cpp` * Adding a `void* cuda_stream` field to the `AOTIDelegateHandle` struct in `aoti_model_container.h` to store the CUDA stream * Defining a new macro `ET_CHECK_OR_LOG` in `log.h` to check a condition and log an error message if the condition is false. Differential Revision: [D84128173](https://our.internmc.facebook.com/intern/diff/D84128173/) [ghstack-poisoned] --- backends/aoti/aoti_model_container.h | 2 ++ backends/cuda/runtime/cuda_backend.cpp | 22 ++++++++++++++++++++-- runtime/platform/log.h | 22 ++++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 844bd2d5a77..9b185327172 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -77,6 +77,8 @@ struct AOTIDelegateHandle { void* so_handle; std::string so_path; AOTInductorModelContainerHandle container_handle; + void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header + // dependency }; } // namespace aoti diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 08031ce6a26..b25bbef6b04 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include #include #include #include @@ -16,7 +17,6 @@ #include #include -#include #include #include @@ -24,6 +24,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final handle->so_handle = so_handle; handle->so_path = so_path.string(); handle->container_handle = container_handle; + + // Create a CUDA stream for asynchronous execution + cudaStream_t cuda_stream; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream)); + handle->cuda_stream = static_cast<void*>(cuda_stream); + return (DelegateHandle*)handle; // Return the handle post-processing } @@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final n_inputs, gpu_outputs.data(), // Use GPU output tensors n_outputs, - nullptr, // Pass the actual CUDA stream! + handle->cuda_stream, // Pass the actual CUDA stream nullptr); // proxy_executor_handle can remain nullptr if (error != Error::Ok) { @@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final } AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + // Destroy the CUDA stream if it exists + if (handle->cuda_stream != nullptr) { + cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream); + cudaError_t stream_err = cudaStreamDestroy(cuda_stream); + ET_CHECK_OR_LOG( + stream_err == cudaSuccess, + "Failed to destroy CUDA stream: %s", + cudaGetErrorString(stream_err)); + handle->cuda_stream = nullptr; + } + // Delete the container BEFORE closing the shared library if (handle->container_handle != nullptr) { AOTIRuntimeError delete_result = diff --git a/runtime/platform/log.h b/runtime/platform/log.h index 72ea8528442..8824e4a25cc 100644 --- a/runtime/platform/log.h +++ b/runtime/platform/log.h @@ -181,6 +181,20 @@ using ::executorch::runtime::LogLevel; ##__VA_ARGS__); \ } \ } while (0) + +/** + * Check a condition and log an error message if the condition is false. + * + * @param[in] _condition The condition to check. 
+ * @param[in] _format Log message format string. + */ +#define ET_CHECK_OR_LOG(_condition, _format, ...) \ + do { \ + if (!(_condition)) { \ + ET_LOG(Error, _format, ##__VA_ARGS__); \ + } \ + } while (0) + #else // ET_LOG_ENABLED /** @@ -191,4 +205,12 @@ using ::executorch::runtime::LogLevel; */ #define ET_LOG(_level, _format, ...) ((void)0) +/** + * Check a condition and log an error message if the condition is false. + * + * @param[in] _condition The condition to check. + * @param[in] _format Log message format string. + */ +#define ET_CHECK_OR_LOG(_condition, _format, ...) ((void)0) + #endif // ET_LOG_ENABLED