diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h
index 844bd2d5a77..9b185327172 100644
--- a/backends/aoti/aoti_model_container.h
+++ b/backends/aoti/aoti_model_container.h
@@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
   void* so_handle;
   std::string so_path;
   AOTInductorModelContainerHandle container_handle;
+  void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
+                     // dependency
 };
 
 } // namespace aoti
diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index 08031ce6a26..5f113b1ce68 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include
 #include
 #include
 #include
@@ -16,7 +17,6 @@
 #include
 #include
 
-#include
 #include
 #include
 
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include
 
 namespace executorch {
 namespace backends {
@@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final
     handle->so_handle = so_handle;
     handle->so_path = so_path.string();
     handle->container_handle = container_handle;
+
+    // Create a CUDA stream for asynchronous execution
+    cudaStream_t cuda_stream;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
+    handle->cuda_stream = static_cast<void*>(cuda_stream);
+
     return (DelegateHandle*)handle; // Return the handle post-processing
   }
 
@@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final
         n_inputs,
         gpu_outputs.data(), // Use GPU output tensors
         n_outputs,
-        nullptr, // Pass the actual CUDA stream!
+        handle->cuda_stream, // Pass the actual CUDA stream
         nullptr); // proxy_executor_handle can remain nullptr
 
     if (error != Error::Ok) {
@@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final
     }
     AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
 
+    // Destroy the CUDA stream if it exists
+    if (handle->cuda_stream != nullptr) {
+      cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
+      cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
+      ET_CHECK_OR_LOG_ERROR(
+          stream_err == cudaSuccess,
+          "Failed to destroy CUDA stream: %s",
+          cudaGetErrorString(stream_err));
+      handle->cuda_stream = nullptr;
+    }
+
     // Delete the container BEFORE closing the shared library
     if (handle->container_handle != nullptr) {
       AOTIRuntimeError delete_result =
diff --git a/runtime/platform/log.h b/runtime/platform/log.h
index 72ea8528442..7293fa2428d 100644
--- a/runtime/platform/log.h
+++ b/runtime/platform/log.h
@@ -181,6 +181,20 @@ using ::executorch::runtime::LogLevel;
       ##__VA_ARGS__);                                    \
     }                                                    \
   } while (0)
+
+/**
+ * Check a condition and log an error message if the condition is false.
+ *
+ * @param[in] _condition The condition to check.
+ * @param[in] _format Log message format string.
+ */
+#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \
+  do {                                                   \
+    if (!(_condition)) {                                 \
+      ET_LOG(Error, _format, ##__VA_ARGS__);             \
+    }                                                    \
+  } while (0)
+
 #else // ET_LOG_ENABLED
 
 /**
  * ...
  */
 #define ET_LOG(_level, _format, ...) ((void)0)
 
+/**
+ * Check a condition and log an error message if the condition is false.
+ *
+ * @param[in] _condition The condition to check.
+ * @param[in] _format Log message format string.
+ */
+#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) ((void)0)
+
 #endif // ET_LOG_ENABLED
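
Reviewer note (not part of the patch): below is a minimal standalone sketch of the stream lifecycle this change adopts: create the stream at init, hand it to the AOTI container during execute, destroy it in destroy, with the stream stored behind a void* so headers that must stay CUDA-free only see an opaque pointer. The struct and function names in the sketch are illustrative, not the actual backend code.

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical stand-in for AOTIDelegateHandle: the stream is kept as void*.
struct Handle {
  void* cuda_stream = nullptr;
};

bool init(Handle& h) {
  cudaStream_t stream = nullptr;
  if (cudaStreamCreate(&stream) != cudaSuccess) {
    return false;
  }
  h.cuda_stream = static_cast<void*>(stream); // store opaquely
  return true;
}

void execute(Handle& h) {
  cudaStream_t stream = static_cast<cudaStream_t>(h.cuda_stream);
  // Kernels (or the AOTI container run) would be enqueued on `stream` here;
  // synchronize before reading results back on the host.
  cudaStreamSynchronize(stream);
}

void destroy(Handle& h) {
  if (h.cuda_stream != nullptr) {
    cudaError_t err =
        cudaStreamDestroy(static_cast<cudaStream_t>(h.cuda_stream));
    if (err != cudaSuccess) {
      // Mirrors ET_CHECK_OR_LOG_ERROR: report the failure but keep tearing down.
      std::fprintf(
          stderr, "cudaStreamDestroy failed: %s\n", cudaGetErrorString(err));
    }
    h.cuda_stream = nullptr;
  }
}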
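
Reviewer note on the new macro (not part of the patch): ET_CHECK_OR_LOG_ERROR only logs when the condition is false and otherwise leaves control flow unchanged, unlike check-and-return helpers such as the ET_CUDA_CHECK_OR_RETURN_ERROR used earlier in this diff; that is what the destroy path relies on. Also note that in an ET_LOG_ENABLED=0 build it expands to ((void)0), so the condition is not evaluated at all and call sites should not rely on side effects inside it. A hypothetical call site (close_result is an assumed local variable, not from the patch):

// Log the failure, then continue with the rest of the teardown.
ET_CHECK_OR_LOG_ERROR(
    close_result == 0,
    "dlclose failed with code %d",
    close_result);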