diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp index 03be835a0c3..46a246faeb8 100644 --- a/backends/aoti/aoti_model_container.cpp +++ b/backends/aoti/aoti_model_container.cpp @@ -25,6 +25,13 @@ AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs = nullptr; AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; +// Additional global function pointers for AOT Inductor model container +// operations needed by Metal backend +AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName = nullptr; +AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants = nullptr; + } // extern "C" } // namespace aoti diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 9b185327172..877f019c457 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -70,6 +70,26 @@ extern AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs; extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; +// Retrieves the name of an input tensor by index from the AOTI model container. +// Needed by Metal backend +using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** input_name); + +// Retrieves the number of constants from the AOTI model container. +// Needed by Metal backend +using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +// Global function pointers (will be loaded dynamically). +// Needed by Metal backend +extern AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName; +extern AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants; + } // extern "C" // AOTI Delegate Handle structure diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index f0c134a716c..1afd137aa26 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -176,6 +176,12 @@ int32_t aoti_torch_dtype_int64() { return 4; // PyTorch's int64 dtype code } +// Dtype utility function needed by Metal backend. +// Returns the size of the dtype in bytes. +size_t aoti_torch_dtype_element_size(int32_t dtype) { + return dtype_to_element_size(dtype); +} + // Cleanup functions void cleanup_tensor_metadata() { internal::tensor_to_sizes.clear(); diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 5f54cd1c878..b79e4c86715 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -61,6 +61,9 @@ int32_t aoti_torch_dtype_float32(); int32_t aoti_torch_dtype_bfloat16(); int32_t aoti_torch_dtype_int64(); +// Dtype utility function needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype); + // Autograd mode functions int32_t aoti_torch_grad_mode_is_enabled(); void aoti_torch_grad_mode_set_enabled(bool enabled);