From 64207122dbd838fa9d84a6b3f54768d45bf41fc0 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 13:29:55 -0400 Subject: [PATCH] Update [ghstack-poisoned] --- backends/aoti/aoti_model_container.cpp | 6 ++++++ backends/aoti/aoti_model_container.h | 16 ++++++++++++++++ backends/aoti/common_shims.cpp | 5 +++++ backends/aoti/common_shims.h | 3 +++ 4 files changed, 30 insertions(+) diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp index 03be835a0c3..d1764451ab6 100644 --- a/backends/aoti/aoti_model_container.cpp +++ b/backends/aoti/aoti_model_container.cpp @@ -25,6 +25,12 @@ AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs = nullptr; AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; +// Global function pointers needed by Metal backend +AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName = nullptr; +AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants = nullptr; + } // extern "C" } // namespace aoti diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 9b185327172..88d936d21ba 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -70,6 +70,22 @@ extern AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs; extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; +// Function pointer types needed by Metal backend +using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** input_name); + +using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +// Global function pointers needed by Metal backend +extern AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName; +extern AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants; + } // extern "C" // AOTI Delegate Handle structure diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index abc83779443..7802444e97e 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -145,6 +145,11 @@ void cleanup_tensor_metadata() { internal::tensor_to_strides.clear(); } +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype) { + return dtype_to_element_size(dtype); +} + } // extern "C" } // namespace aoti diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 5f54cd1c878..97fcea1085c 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -68,6 +68,9 @@ void aoti_torch_grad_mode_set_enabled(bool enabled); // Cleanup functions for clearing global state void cleanup_tensor_metadata(); +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype); + } // extern "C" } // namespace aoti