diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt index fcabb0a3f2b..2c836101c5e 100644 --- a/backends/aoti/CMakeLists.txt +++ b/backends/aoti/CMakeLists.txt @@ -26,7 +26,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) find_package_torch() # Common AOTI functionality - combines all AOTI common components -set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp) +set(_aoti_common_sources common_shims.cpp) add_library(aoti_common STATIC ${_aoti_common_sources}) target_include_directories( aoti_common diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_delegate_handle.h similarity index 78% rename from backends/aoti/aoti_model_container.h rename to backends/aoti/aoti_delegate_handle.h index 877f019c457..2e72fc39821 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_delegate_handle.h @@ -60,36 +60,17 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)( AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle); -// Global function pointers (will be loaded dynamically) -extern AOTInductorModelContainerCreateWithDeviceFunc - AOTInductorModelContainerCreateWithDevice; -extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete; -extern AOTInductorModelContainerGetNumInputsFunc - AOTInductorModelContainerGetNumInputs; -extern AOTInductorModelContainerGetNumOutputsFunc - AOTInductorModelContainerGetNumOutputs; -extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; - // Retrieves the name of an input tensor by index from the AOTI model container. -// Needed by Metal backend using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( AOTInductorModelContainerHandle container_handle, size_t input_idx, const char** input_name); // Retrieves the number of constants from the AOTI model container. -// Needed by Metal backend using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( AOTInductorModelContainerHandle container_handle, size_t* num_constants); -// Global function pointers (will be loaded dynamically). -// Needed by Metal backend -extern AOTInductorModelContainerGetInputNameFunc - AOTInductorModelContainerGetInputName; -extern AOTInductorModelContainerGetNumConstantsFunc - AOTInductorModelContainerGetNumConstants; - } // extern "C" // AOTI Delegate Handle structure @@ -99,6 +80,13 @@ struct AOTIDelegateHandle { AOTInductorModelContainerHandle container_handle; void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header // dependency + + // Function pointers specific to this handle's shared library + AOTInductorModelContainerCreateWithDeviceFunc create_with_device; + AOTInductorModelContainerDeleteFunc delete_container; + AOTInductorModelContainerGetNumInputsFunc get_num_inputs; + AOTInductorModelContainerGetNumOutputsFunc get_num_outputs; + AOTInductorModelContainerRunFunc run; }; } // namespace aoti diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp deleted file mode 100644 index 46a246faeb8..00000000000 --- a/backends/aoti/aoti_model_container.cpp +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -namespace executorch { -namespace backends { -namespace aoti { - -extern "C" { - -// Global function pointers for AOT Inductor model container operations -// These will be loaded dynamically from the shared library -AOTInductorModelContainerCreateWithDeviceFunc - AOTInductorModelContainerCreateWithDevice = nullptr; -AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete = nullptr; -AOTInductorModelContainerGetNumInputsFunc - AOTInductorModelContainerGetNumInputs = nullptr; -AOTInductorModelContainerGetNumOutputsFunc - AOTInductorModelContainerGetNumOutputs = nullptr; -AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; - -// Additional global function pointers for AOT Inductor model container -// operations needed by Metal backend -AOTInductorModelContainerGetInputNameFunc - AOTInductorModelContainerGetInputName = nullptr; -AOTInductorModelContainerGetNumConstantsFunc - AOTInductorModelContainerGetNumConstants = nullptr; - -} // extern "C" - -} // namespace aoti -} // namespace backends -} // namespace executorch diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl index 8bf44573bb3..b7386403679 100644 --- a/backends/aoti/targets.bzl +++ b/backends/aoti/targets.bzl @@ -25,12 +25,9 @@ def define_common_targets(): # AOTI model container functionality runtime.cxx_library( - name = "model_container", - srcs = [ - "aoti_model_container.cpp", - ], + name = "delegate_handle", headers = [ - "aoti_model_container.h", + "aoti_delegate_handle.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, @@ -44,7 +41,7 @@ def define_common_targets(): ], ) - # Common AOTI functionality (combining both common_shims and model_container) + # Common AOTI functionality (combining both common_shims and delegate_handle) runtime.cxx_library( name = "aoti_common", # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) @@ -53,6 +50,6 @@ def define_common_targets(): visibility = ["@EXECUTORCH_CLIENTS"], exported_deps = [ ":common_shims", - ":model_container", + ":delegate_handle", ], ) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 805c54ff55c..3fcd25a3d1d 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -21,18 +21,18 @@ #include // Include our shim layer headers -#include +#include #include #include #include namespace executorch::backends::cuda { -#define LOAD_SYMBOL(name, handle) \ - do { \ - name = reinterpret_cast(dlsym(handle, #name)); \ - ET_CHECK_OR_RETURN_ERROR( \ - name != nullptr, AccessFailed, "Failed to load " #name); \ +#define LOAD_SYMBOL(handle, member, name, so_handle) \ + do { \ + handle->member = reinterpret_cast(dlsym(so_handle, #name)); \ + ET_CHECK_OR_RETURN_ERROR( \ + handle->member != nullptr, AccessFailed, "Failed to load " #name); \ } while (0) using namespace std; @@ -57,12 +57,31 @@ using executorch::runtime::etensor::Tensor; class ET_EXPERIMENTAL CudaBackend final : public ::executorch::runtime::BackendInterface { private: - Error register_shared_library_functions(void* so_handle) const { - LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle); - LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle); - LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle); - LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle); - LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle); + Error load_function_pointers_into_handle( + void* so_handle, + AOTIDelegateHandle* handle) const { + LOAD_SYMBOL( + handle, + create_with_device, + AOTInductorModelContainerCreateWithDevice, + so_handle); + + LOAD_SYMBOL( + handle, delete_container, AOTInductorModelContainerDelete, so_handle); + + LOAD_SYMBOL( + handle, + get_num_inputs, + AOTInductorModelContainerGetNumInputs, + so_handle); + + LOAD_SYMBOL( + handle, + get_num_outputs, + AOTInductorModelContainerGetNumOutputs, + so_handle); + + LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle); return Error::Ok; } @@ -135,19 +154,22 @@ class ET_EXPERIMENTAL CudaBackend final processed->Free(); - // Register all shared library functions - ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle)); + // Create handle and load function pointers into it + AOTIDelegateHandle* handle = new AOTIDelegateHandle(); + handle->so_handle = so_handle; + handle->so_path = so_path.string(); + + // Load function pointers specific to this handle's shared library + ET_CHECK_OK_OR_RETURN_ERROR( + load_function_pointers_into_handle(so_handle, handle)); AOTInductorModelContainerHandle container_handle = nullptr; - ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice( - &container_handle, 1, "cuda", nullptr)); + ET_CHECK_OK_OR_RETURN_ERROR( + handle->create_with_device(&container_handle, 1, "cuda", nullptr)); ET_LOG(Info, "container_handle = %p", container_handle); - AOTIDelegateHandle* handle = new AOTIDelegateHandle(); - handle->so_handle = so_handle; - handle->so_path = so_path.string(); handle->container_handle = container_handle; // Create a CUDA stream for asynchronous execution @@ -165,20 +187,11 @@ class ET_EXPERIMENTAL CudaBackend final Span args) const override { AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; - // Need to re-register all the symbols from the so_handle hosted by this - // CudaBackend instance. The reason is that these symbols are - // static/singleton across the whole process. When we share multiple methods - // (meaning multiple so_handle) in the same process, we need to re-register - // the symbols from the so_handle that is being used in this execution. - ET_CHECK_OK_OR_RETURN_ERROR( - register_shared_library_functions(handle->so_handle)); - size_t n_inputs; - AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs); + handle->get_num_inputs(handle->container_handle, &n_inputs); size_t n_outputs; - AOTInductorModelContainerGetNumOutputs( - handle->container_handle, &n_outputs); + handle->get_num_outputs(handle->container_handle, &n_outputs); ET_CHECK_OR_RETURN_ERROR( n_inputs + n_outputs == args.size(), @@ -261,7 +274,7 @@ class ET_EXPERIMENTAL CudaBackend final gpu_outputs[i] = gpu_output_handle; } // Run AOTI container with GPU tensors - AOTIRuntimeError error = AOTInductorModelContainerRun( + AOTIRuntimeError error = handle->run( handle->container_handle, gpu_inputs.data(), // Use GPU input tensors n_inputs,