From 614e079f77678df2ef7b5a7fbf56b39dd6ee393f Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Thu, 25 Sep 2025 12:23:11 -0700
Subject: [PATCH] common library for et-aoti-driven operators (#14492)

Summary:
This diff introduces common functions shared by all AOTI-driven backends
under ExecuTorch, such as CUDA and MPS. It contains two major function
families: container functions for holding and running AOTI programs, and
common shim layers for the AOTI library.

Note that the functions living here should be backend-agnostic;
backend-specific functions should live inside the corresponding backend
directory instead.

Reviewed By: larryliu0820

Differential Revision: D83003496
---
 backends/aoti/CMakeLists.txt              |  54 ++++
 backends/aoti/README.md                   |  60 ++++
 backends/aoti/TARGETS                     |   3 +
 backends/aoti/aoti_model_container.cpp    |  32 ++
 backends/aoti/aoti_model_container.h      |  82 ++++++
 backends/aoti/common_shims.cpp            | 145 ++++++++++
 backends/aoti/common_shims.h              |  75 +++++
 backends/aoti/targets.bzl                 |  58 ++++
 backends/aoti/tests/TARGETS               |  21 ++
 backends/aoti/tests/test_common_shims.cpp | 324 ++++++++++++++++++++++
 backends/aoti/tests/utils.h               |  74 +++++
 backends/aoti/utils.h                     |  78 ++++++
 12 files changed, 1006 insertions(+)
 create mode 100644 backends/aoti/CMakeLists.txt
 create mode 100644 backends/aoti/README.md
 create mode 100644 backends/aoti/TARGETS
 create mode 100644 backends/aoti/aoti_model_container.cpp
 create mode 100644 backends/aoti/aoti_model_container.h
 create mode 100644 backends/aoti/common_shims.cpp
 create mode 100644 backends/aoti/common_shims.h
 create mode 100644 backends/aoti/targets.bzl
 create mode 100644 backends/aoti/tests/TARGETS
 create mode 100644 backends/aoti/tests/test_common_shims.cpp
 create mode 100644 backends/aoti/tests/utils.h
 create mode 100644 backends/aoti/utils.h

diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt
new file mode 100644
index 00000000000..2aa8a5692ac
--- /dev/null
+++ b/backends/aoti/CMakeLists.txt
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Build AOTI backend for runtime.
+#
+# ### Editing this file ###
+#
+# This file should be formatted with
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+# It should also be cmake-lint clean.
+#
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+
+# Use ExecuTorch's standard way to find PyTorch libraries for AOTI
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+find_package_torch()
+
+# Common AOTI functionality - combines all AOTI common components
+set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
+add_library(aoti_common STATIC ${_aoti_common_sources})
+target_include_directories(
+  aoti_common
+  PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..> $<INSTALL_INTERFACE:include>
+         # PyTorch AOTI headers from ExecuTorch's torch detection
+         ${TORCH_INCLUDE_DIRS}
+)
+target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
+# Ensure symbols are exported properly
+target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
+
+# Link against PyTorch libraries and standard libraries
+target_link_libraries(
+  aoti_common
+  PUBLIC extension_tensor ${CMAKE_DL_LIBS}
+         # Link PyTorch libraries for AOTI functions
+         ${TORCH_LIBRARIES}
+)
+executorch_target_link_options_shared_lib(aoti_common)
+
+install(
+  TARGETS aoti_common
+  EXPORT ExecuTorchTargets
+  DESTINATION lib
+)
diff --git a/backends/aoti/README.md b/backends/aoti/README.md
new file mode 100644
index 00000000000..74b45a35e5d
--- /dev/null
+++ b/backends/aoti/README.md
@@ -0,0 +1,60 @@
+# AOTI Common Library
+
+This directory contains **common library components** for AOTI (AOTInductor)
+driven backends in ExecuTorch, **not a standalone backend**.
+
+## Purpose
+
+The code in this directory provides shared functionality and utilities that
+are used by actual AOTI-driven backends such as:
+
+- **CUDA backend** - Uses AOTI for GPU acceleration
+- Other AOTI-powered backends
+
+## Components
+
+- **`common_shims.cpp/h`** - Common shim functions that bridge ExecuTorch
+  tensor operations with AOTI requirements
+- **`aoti_model_container.cpp/h`** - Model container functionality for AOTI
+  models
+- **`utils.h`** - Utility functions and type definitions
+- **`tests/`** - Unit tests for the common functionality
+
+## Usage
+
+This library is intended to be used as a dependency by actual AOTI backend
+implementations. It is not a backend that can be used directly for model
+execution.
+
+For example backend implementations that use this common library, see:
+
+- `executorch/backends/cuda/` - CUDA AOTI backend
+
+## Building
+
+The common library components are built as part of the AOTI backend build
+process. See the `TARGETS` file for build configurations.
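+
+## Example
+
+The sketch below shows roughly how a backend wires up the container
+entry points declared in `aoti_model_container.h`. It is illustrative
+only: symbol resolution, device selection, and error handling are owned
+by each backend (see `backends/cuda/` for a real implementation).
+
+```cpp
+// Illustrative sketch: error checks and the <dlfcn.h> include are omitted.
+void* so_handle = dlopen("model.so", RTLD_LAZY | RTLD_LOCAL);
+AOTInductorModelContainerCreateWithDevice =
+    reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
+        dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
+AOTInductorModelContainerRun =
+    reinterpret_cast<AOTInductorModelContainerRunFunc>(
+        dlsym(so_handle, "AOTInductorModelContainerRun"));
+
+// Create a container holding one model instance.
+AOTInductorModelContainerHandle container = nullptr;
+AOTInductorModelContainerCreateWithDevice(
+    &container, /*num_models=*/1, /*device_str=*/"cuda",
+    /*cubin_dir=*/nullptr);
+```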
diff --git a/backends/aoti/TARGETS b/backends/aoti/TARGETS
new file mode 100644
index 00000000000..77871de4469
--- /dev/null
+++ b/backends/aoti/TARGETS
@@ -0,0 +1,3 @@
+load("targets.bzl", "define_common_targets")
+
+define_common_targets()
diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp
new file mode 100644
index 00000000000..03be835a0c3
--- /dev/null
+++ b/backends/aoti/aoti_model_container.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/aoti/aoti_model_container.h>
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+
+extern "C" {
+
+// Global function pointers for AOT Inductor model container operations.
+// These will be loaded dynamically from the shared library.
+AOTInductorModelContainerCreateWithDeviceFunc
+    AOTInductorModelContainerCreateWithDevice = nullptr;
+AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete = nullptr;
+AOTInductorModelContainerGetNumInputsFunc
+    AOTInductorModelContainerGetNumInputs = nullptr;
+AOTInductorModelContainerGetNumOutputsFunc
+    AOTInductorModelContainerGetNumOutputs = nullptr;
+AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;
+
+} // extern "C"
+
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h
new file mode 100644
index 00000000000..4b20aefc976
--- /dev/null
+++ b/backends/aoti/aoti_model_container.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+
+using executorch::runtime::Error;
+using executorch::runtime::etensor::Tensor;
+
+extern "C" {
+
+// Type definitions
+using AOTIRuntimeError = Error;
+
+// Forward declarations for AOT Inductor model container
+struct AOTInductorModelContainerOpaque;
+using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
+using AOTInductorStreamHandle = void*;
+using AOTIProxyExecutorHandle = void*;
+
+// Function pointer types for AOT Inductor model container operations
+using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle* container_handle,
+    size_t num_models,
+    const char* device_str,
+    const char* cubin_dir);
+
+using AOTInductorModelContainerDeleteFunc =
+    AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle);
+
+using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t* num_inputs);
+
+using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t* num_outputs);
+
+using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    Tensor** input_handles, // array of input Tensor*; handles
+                            // are stolen; the array itself is borrowed
+    size_t num_inputs,
+    Tensor** output_handles, // array for writing output Tensor*; handles
+                             // will be stolen by the caller; the array itself
+                             // is borrowed
+    size_t n_outputs,
+    AOTInductorStreamHandle stream_handle,
+    AOTIProxyExecutorHandle proxy_executor_handle);
+
+// Global function pointers (will be loaded dynamically)
+extern AOTInductorModelContainerCreateWithDeviceFunc
+    AOTInductorModelContainerCreateWithDevice;
+extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
+extern AOTInductorModelContainerGetNumInputsFunc
+    AOTInductorModelContainerGetNumInputs;
+extern AOTInductorModelContainerGetNumOutputsFunc
+    AOTInductorModelContainerGetNumOutputs;
+extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
+
+} // extern "C"
+
+// AOTI Delegate Handle structure
+struct AOTIDelegateHandle {
+  void* so_handle;
+  AOTInductorModelContainerHandle container_handle;
+};
+
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp
new file mode 100644
index 00000000000..2f9b36e3c4f
--- /dev/null
+++ b/backends/aoti/common_shims.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/aoti/common_shims.h>
+#include <cstdint>
+#include <stdexcept>
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+
+namespace internal {
+// Global storage for tensor metadata
+std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
+std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
+} // namespace internal
+
+extern "C" {
+
+// Autograd mode functions
+int32_t aoti_torch_grad_mode_is_enabled() {
+  // No autograd ever
+  return false;
+}
+
+void aoti_torch_grad_mode_set_enabled(bool enabled) {
+  if (enabled) {
+    throw std::runtime_error("Cannot enable autograd");
+  }
+}
+
+// Tensor attribute operations
+AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
+  *ret_data_ptr = tensor->mutable_data_ptr();
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_storage_offset(
+    Tensor* tensor,
+    int64_t* ret_storage_offset) {
+  // Storage offset is always 0 in ET
+  *ret_storage_offset = 0;
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
+  auto it = internal::tensor_to_strides.find(tensor);
+  if (it == internal::tensor_to_strides.end()) {
+    std::vector<int64_t> strides(tensor->dim());
+    auto tensor_strides = tensor->strides();
+    for (int i = 0; i < tensor->dim(); i++) {
+      strides[i] = tensor_strides[i];
+    }
+    it = internal::tensor_to_strides.emplace(tensor, std::move(strides)).first;
+  }
+
+  // For 0D tensors, data() returns nullptr on empty vectors, but we need to
+  // return a valid pointer
+  if (it->second.empty()) {
+    static int64_t empty_strides_placeholder = 0;
+    *ret_strides = &empty_strides_placeholder;
+  } else {
+    *ret_strides = it->second.data();
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
+  *ret_dtype = static_cast<int32_t>(tensor->scalar_type());
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
+  auto it = internal::tensor_to_sizes.find(tensor);
+  if (it == internal::tensor_to_sizes.end()) {
+    std::vector<int64_t> sizes(tensor->dim());
+    auto tensor_sizes = tensor->sizes();
+    for (int i = 0; i < tensor->dim(); i++) {
+      sizes[i] = tensor_sizes[i];
+    }
+    it = internal::tensor_to_sizes.emplace(tensor, std::move(sizes)).first;
+  }
+
+  // For 0D tensors, data() returns nullptr on empty vectors, but we need to
+  // return a valid pointer
+  if (it->second.empty()) {
+    static int64_t empty_sizes_placeholder = 0;
+    *ret_sizes = &empty_sizes_placeholder;
+  } else {
+    *ret_sizes = it->second.data();
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_device_index(
+    Tensor* tensor,
+    int32_t* ret_device_index) {
+  // Assume all tensors used by AOTI live on device index 0
+  *ret_device_index = 0;
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
+  *ret_dim = static_cast<int64_t>(tensor->dim());
+  return Error::Ok;
+}
+
+// Device and layout utility functions
+int32_t aoti_torch_device_type_cpu() {
+  // CPU is device type 0 in ET as well
+  return 0;
+}
+
+int32_t aoti_torch_layout_strided() {
+  // ET only supports the strided layout, so the return value is always 0,
+  // i.e. at::Layout::Strided
+  return 0;
+}
+
+// Dtype constants - these return the PyTorch dtype codes
+// Currently only float32 is supported, but using robust enum-based approach
+int32_t aoti_torch_dtype_float32() {
+  return 6; // PyTorch's float32 dtype code
+}
+
+// Cleanup functions
+void cleanup_tensor_metadata() {
+  internal::tensor_to_sizes.clear();
+  internal::tensor_to_strides.clear();
+}
+
+} // extern "C"
+
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h
new file mode 100644
index 00000000000..ffcbaa11a08
--- /dev/null
+++ b/backends/aoti/common_shims.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+
+#include <cstdint>
+#include <unordered_map>
+#include <vector>
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+
+// Common using declarations for ExecuTorch types
+using executorch::runtime::Error;
+using executorch::runtime::etensor::Tensor;
+
+// Global storage for tensor metadata, defined in common_shims.cpp
+namespace internal {
+extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
+extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
+} // namespace internal
+
+extern "C" {
+
+// Common AOTI type aliases
+using AOTIRuntimeError = Error;
+using AOTITorchError = Error;
+
+// Attribute-related operations (memory-irrelevant)
+AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);
+
+AOTITorchError aoti_torch_get_storage_offset(
+    Tensor* tensor,
+    int64_t* ret_storage_offset);
+
+AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);
+
+AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);
+
+AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);
+
+AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);
+
+AOTITorchError aoti_torch_get_device_index(
+    Tensor* tensor,
+    int32_t* ret_device_index);
+
+AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
+
+// Utility functions for device and layout information
+int32_t aoti_torch_device_type_cpu();
+int32_t aoti_torch_layout_strided();
+int32_t aoti_torch_dtype_float32();
+
+// Autograd mode functions
+int32_t aoti_torch_grad_mode_is_enabled();
+void aoti_torch_grad_mode_set_enabled(bool enabled);
+
+// Cleanup functions for clearing global state
+void cleanup_tensor_metadata();
+
+} // extern "C"
+
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl
new file mode 100644
index 00000000000..79f082e5a89
--- /dev/null
+++ b/backends/aoti/targets.bzl
@@ -0,0 +1,58 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    # AOTI common shims functionality
+    runtime.cxx_library(
+        name = "common_shims",
+        srcs = [
+            "common_shims.cpp",
+        ],
+        headers = [
+            "common_shims.h",
+            "utils.h",
+        ],
+        # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+        link_whole = True,
+        supports_python_dlopen = True,
+        # Constructor needed for backend registration.
+        compiler_flags = ["-Wno-global-constructors"],
+        visibility = ["@EXECUTORCH_CLIENTS"],
+        deps = [
+            "//executorch/runtime/core:core",
+            "//executorch/runtime/core/exec_aten:lib",
+        ],
+    )
+
+    # AOTI model container functionality
+    runtime.cxx_library(
+        name = "model_container",
+        srcs = [
+            "aoti_model_container.cpp",
+        ],
+        headers = [
+            "aoti_model_container.h",
+        ],
+        # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+        link_whole = True,
+        supports_python_dlopen = True,
+        # Constructor needed for backend registration.
+        compiler_flags = ["-Wno-global-constructors"],
+        visibility = ["@EXECUTORCH_CLIENTS"],
+        deps = [
+            "//executorch/runtime/backend:interface",
+            "//executorch/runtime/core:core",
+        ],
+    )
+
+    # Common AOTI functionality (combining both common_shims and model_container)
+    runtime.cxx_library(
+        name = "aoti_common",
+        # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+        link_whole = True,
+        supports_python_dlopen = True,
+        visibility = ["@EXECUTORCH_CLIENTS"],
+        deps = [
+            ":common_shims",
+            ":model_container",
+        ],
+    )
diff --git a/backends/aoti/tests/TARGETS b/backends/aoti/tests/TARGETS
new file mode 100644
index 00000000000..8daa8abd4d7
--- /dev/null
+++ b/backends/aoti/tests/TARGETS
@@ -0,0 +1,21 @@
+load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
+
+oncall("executorch")
+
+cpp_unittest(
+    name = "test_common_shims",
+    srcs = [
+        "test_common_shims.cpp",
+    ],
+    headers = [
+        "utils.h",
+    ],
+    deps = [
+        "//executorch/backends/aoti:common_shims",
+        "//executorch/extension/tensor:tensor",
+        "//executorch/runtime/core:core",
+        "//executorch/runtime/platform:platform",
+        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        "//executorch/runtime/core/exec_aten:lib",
+    ],
+)
diff --git a/backends/aoti/tests/test_common_shims.cpp b/backends/aoti/tests/test_common_shims.cpp
new file mode 100644
index 00000000000..980eae96122
--- /dev/null
+++ b/backends/aoti/tests/test_common_shims.cpp
@@ -0,0 +1,324 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/aoti/tests/utils.h>
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/runtime.h>
+#include <gtest/gtest.h>
+
+using namespace executorch::backends::aoti;
+using namespace executorch::backends::aoti::test;
+using namespace executorch::runtime;
+using executorch::runtime::etensor::Tensor;
+
+// Test fixture for common shims tests
+class CommonShimsTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Clean up any existing cached metadata before each test
+    cleanup_tensor_metadata();
+  }
+
+  void TearDown() override {
+    // Clean up metadata and free any tensor data
+    cleanup_tensor_metadata();
+    for (auto& tensor : test_tensors_) {
+      free_tensor_data(tensor.get());
+    }
+    test_tensors_.clear();
+  }
+
+  // Helper to create and track test tensors for cleanup
+  Tensor* create_tracked_tensor(const std::vector<int64_t>& sizes) {
+    auto tensor = create_test_tensor(sizes);
+    Tensor* ptr = tensor.get();
+    test_tensors_.push_back(tensor);
+    return ptr;
+  }
+
+ private:
+  std::vector<std::shared_ptr<Tensor>> test_tensors_;
+};
+
+// Test aoti_torch_get_sizes basic functionality
+TEST_F(CommonShimsTest, GetSizesBasicFunctionality) {
+  // Test 1D tensor
+  auto tensor_1d = create_tracked_tensor({5});
+  int64_t* sizes_ptr;
+  AOTITorchError error = aoti_torch_get_sizes(tensor_1d, &sizes_ptr);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(sizes_ptr, nullptr);
+  EXPECT_EQ(sizes_ptr[0], 5);
+
+  // Test 2D tensor
+  auto tensor_2d = create_tracked_tensor({3, 4});
+  error = aoti_torch_get_sizes(tensor_2d, &sizes_ptr);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(sizes_ptr, nullptr);
+  EXPECT_EQ(sizes_ptr[0], 3);
+  EXPECT_EQ(sizes_ptr[1], 4);
+
+  // Test 3D tensor
+  auto tensor_3d = create_tracked_tensor({2, 3, 4});
+  error = aoti_torch_get_sizes(tensor_3d, &sizes_ptr);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(sizes_ptr, nullptr);
+  EXPECT_EQ(sizes_ptr[0], 2);
+  EXPECT_EQ(sizes_ptr[1], 3);
+  EXPECT_EQ(sizes_ptr[2], 4);
+}
+
+// Test aoti_torch_get_strides basic functionality
+TEST_F(CommonShimsTest, GetStridesBasicFunctionality) {
+  // Test 1D tensor
+  auto tensor_1d = create_tracked_tensor({5});
+  int64_t* strides_ptr;
+  AOTITorchError error = aoti_torch_get_strides(tensor_1d, &strides_ptr);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(strides_ptr, nullptr);
+  EXPECT_EQ(strides_ptr[0], 1);
+
+  // Test 2D tensor - row major: [3, 4] should have strides [4, 1]
+  auto tensor_2d = create_tracked_tensor({3, 4});
+  error = aoti_torch_get_strides(tensor_2d, &strides_ptr);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(strides_ptr, nullptr);
+  EXPECT_EQ(strides_ptr[0], 4);
+  EXPECT_EQ(strides_ptr[1], 1);
+
+  // Test 3D tensor - row major: [2, 3, 4] should have strides [12, 4, 1]
+  auto tensor_3d = create_tracked_tensor({2, 3, 4});
+  error = aoti_torch_get_strides(tensor_3d, &strides_ptr);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(strides_ptr, nullptr);
+  EXPECT_EQ(strides_ptr[0], 12);
+  EXPECT_EQ(strides_ptr[1], 4);
+  EXPECT_EQ(strides_ptr[2], 1);
+}
+
+// Test caching logic for sizes
+TEST_F(CommonShimsTest, SizesCachingLogic) {
+  auto tensor = create_tracked_tensor({2, 3, 4});
+
+  // First call should cache the sizes
+  int64_t* sizes_ptr1;
+  AOTITorchError error = aoti_torch_get_sizes(tensor, &sizes_ptr1);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(sizes_ptr1, nullptr);
+
+  // Second call should return the same cached pointer
+  int64_t* sizes_ptr2;
+  error = aoti_torch_get_sizes(tensor, &sizes_ptr2);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(sizes_ptr1, sizes_ptr2); // Should be the exact same pointer
+
+  // Values should still be correct
+  EXPECT_EQ(sizes_ptr2[0], 2);
+  EXPECT_EQ(sizes_ptr2[1], 3);
+  EXPECT_EQ(sizes_ptr2[2], 4);
+}
+
+// Test caching logic for strides
+TEST_F(CommonShimsTest, StridesCachingLogic) {
+  auto tensor = create_tracked_tensor({2, 3, 4});
+
+  // First call should cache the strides
+  int64_t* strides_ptr1;
+  AOTITorchError error = aoti_torch_get_strides(tensor, &strides_ptr1);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(strides_ptr1, nullptr);
+
+  // Second call should return the same cached pointer
+  int64_t* strides_ptr2;
+  error = aoti_torch_get_strides(tensor, &strides_ptr2);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(strides_ptr1, strides_ptr2); // Should be the exact same pointer
+
+  // Values should still be correct
+  EXPECT_EQ(strides_ptr2[0], 12);
+  EXPECT_EQ(strides_ptr2[1], 4);
+  EXPECT_EQ(strides_ptr2[2], 1);
+}
+
+// Test that different tensors have different cached entries
+TEST_F(CommonShimsTest, DifferentTensorsCacheSeparately) {
+  auto tensor1 = create_tracked_tensor({2, 3});
+  auto tensor2 = create_tracked_tensor({4, 5});
+
+  // Get sizes for both tensors
+  int64_t* sizes1_ptr;
+  int64_t* sizes2_ptr;
+
+  EXPECT_EQ(aoti_torch_get_sizes(tensor1, &sizes1_ptr), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_sizes(tensor2, &sizes2_ptr), Error::Ok);
+
+  // Pointers should be different (different cache entries)
+  EXPECT_NE(sizes1_ptr, sizes2_ptr);
+
+  // Values should be correct
+  EXPECT_EQ(sizes1_ptr[0], 2);
+  EXPECT_EQ(sizes1_ptr[1], 3);
+  EXPECT_EQ(sizes2_ptr[0], 4);
+  EXPECT_EQ(sizes2_ptr[1], 5);
+
+  // Test strides as well
+  int64_t* strides1_ptr;
+  int64_t* strides2_ptr;
+
+  EXPECT_EQ(aoti_torch_get_strides(tensor1, &strides1_ptr), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor2, &strides2_ptr), Error::Ok);
+
+  // Pointers should be different (different cache entries)
+  EXPECT_NE(strides1_ptr, strides2_ptr);
+
+  // Values should be correct
+  EXPECT_EQ(strides1_ptr[0], 3);
+  EXPECT_EQ(strides1_ptr[1], 1);
+  EXPECT_EQ(strides2_ptr[0], 5);
+  EXPECT_EQ(strides2_ptr[1], 1);
+}
+
+// Test cache persistence across multiple calls
+TEST_F(CommonShimsTest, CachePersistence) {
+  auto tensor = create_tracked_tensor({3, 4, 5});
+
+  // Multiple calls to sizes should all return the same pointer
+  int64_t* sizes_ptr1;
+  int64_t* sizes_ptr2;
+  int64_t* sizes_ptr3;
+
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr1), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr2), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr3), Error::Ok);
+
+  EXPECT_EQ(sizes_ptr1, sizes_ptr2);
+  EXPECT_EQ(sizes_ptr2, sizes_ptr3);
+
+  // Multiple calls to strides should all return the same pointer
+  int64_t* strides_ptr1;
+  int64_t* strides_ptr2;
+  int64_t* strides_ptr3;
+
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr1), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr3), Error::Ok);
+
+  EXPECT_EQ(strides_ptr1, strides_ptr2);
+  EXPECT_EQ(strides_ptr2, strides_ptr3);
+}
+
+// Test 0D tensor (scalar)
+TEST_F(CommonShimsTest, ScalarTensor) {
+  auto tensor_0d = create_tracked_tensor({});
+
+  // Test sizes for 0D tensor
+  int64_t* sizes_ptr;
+  AOTITorchError error = aoti_torch_get_sizes(tensor_0d, &sizes_ptr);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(sizes_ptr, nullptr);
+
+  // Test strides for 0D tensor
+  int64_t* strides_ptr;
+  error = aoti_torch_get_strides(tensor_0d, &strides_ptr);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(strides_ptr, nullptr);
+
+  // Cache should work for 0D tensors too
+  int64_t* sizes_ptr2;
+  error = aoti_torch_get_sizes(tensor_0d, &sizes_ptr2);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(sizes_ptr, sizes_ptr2);
+}
+
+// Test large tensor dimensions
+TEST_F(CommonShimsTest, LargeTensorDimensions) {
+  auto tensor = create_tracked_tensor({100, 200, 300, 400});
+
+  // Test sizes
+  int64_t* sizes_ptr;
+  AOTITorchError error = aoti_torch_get_sizes(tensor, &sizes_ptr);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(sizes_ptr, nullptr);
+  EXPECT_EQ(sizes_ptr[0], 100);
+  EXPECT_EQ(sizes_ptr[1], 200);
+  EXPECT_EQ(sizes_ptr[2], 300);
+  EXPECT_EQ(sizes_ptr[3], 400);
+
+  // Test strides - expected: [24000000, 120000, 400, 1]
+  int64_t* strides_ptr;
+  error = aoti_torch_get_strides(tensor, &strides_ptr);
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(strides_ptr, nullptr);
+  EXPECT_EQ(strides_ptr[0], 24000000);
+  EXPECT_EQ(strides_ptr[1], 120000);
+  EXPECT_EQ(strides_ptr[2], 400);
+  EXPECT_EQ(strides_ptr[3], 1);
+}
+
+// Test that cleanup_tensor_metadata clears the cache
+TEST_F(CommonShimsTest, CleanupFunctionality) {
+  auto tensor = create_tracked_tensor({2, 3});
+
+  // Cache some data
+  int64_t* sizes_ptr1;
+  int64_t* strides_ptr1;
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr1), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr1), Error::Ok);
+
+  // Clear the cache
+  cleanup_tensor_metadata();
+
+  // Getting sizes/strides again should create new cache entries
+  // (We can't directly test if the pointers are different since that would be
+  // implementation-dependent, but we can at least verify the functions still
+  // work)
+  int64_t* sizes_ptr2;
+  int64_t* strides_ptr2;
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr2), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok);
+
+  // Values should still be correct
+  EXPECT_EQ(sizes_ptr2[0], 2);
+  EXPECT_EQ(sizes_ptr2[1], 3);
+  EXPECT_EQ(strides_ptr2[0], 3);
+  EXPECT_EQ(strides_ptr2[1], 1);
+}
+
+// Test mixed operations to ensure caches are independent
+TEST_F(CommonShimsTest, IndependentCaches) {
+  auto tensor = create_tracked_tensor({2, 3, 4});
+
+  // Get sizes first
+  int64_t* sizes_ptr1;
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr1), Error::Ok);
+
+  // Get strides
+  int64_t* strides_ptr1;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr1), Error::Ok);
+
+  // Get sizes again - should be cached
+  int64_t* sizes_ptr2;
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr2), Error::Ok);
+  EXPECT_EQ(sizes_ptr1, sizes_ptr2);
+
+  // Get strides again - should be cached
+  int64_t* strides_ptr2;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok);
+  EXPECT_EQ(strides_ptr1, strides_ptr2);
+
+  // Sizes and strides pointers should be different (different caches)
+  EXPECT_NE(sizes_ptr1, strides_ptr1);
+}
diff --git a/backends/aoti/tests/utils.h b/backends/aoti/tests/utils.h
new file mode 100644
index 00000000000..1f26f7e2d51
--- /dev/null
+++ b/backends/aoti/tests/utils.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <cstdlib>
+#include <vector>
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+namespace test {
+
+// Use the same type aliases as in common_shims.h
+using executorch::runtime::etensor::Tensor;
+
+/**
+ * Creates a test tensor with the specified shape and scalar type
+ */
+inline std::shared_ptr<Tensor> create_test_tensor(
+    const std::vector<int64_t>& sizes,
+    exec_aten::ScalarType dtype = exec_aten::ScalarType::Float) {
+  // Calculate total number of elements
+  int64_t total_elements = 1;
+  for (int64_t size : sizes) {
+    total_elements *= size;
+  }
+
+  // Calculate strides (row-major layout)
+  std::vector<int64_t> strides(sizes.size());
+  if (sizes.size() > 0) {
+    strides[sizes.size() - 1] = 1;
+    for (int i = sizes.size() - 2; i >= 0; i--) {
+      strides[i] = strides[i + 1] * sizes[i + 1];
+    }
+  }
+
+  // Allocate data buffer
+  size_t dtype_size = exec_aten::elementSize(dtype);
+  void* data = malloc(total_elements * dtype_size);
+
+  // Convert sizes and strides to the required type
+  std::vector<exec_aten::SizesType> sizes_converted(
+      sizes.begin(), sizes.end());
+  std::vector<exec_aten::StridesType> strides_converted(
+      strides.begin(), strides.end());
+
+  // Create the tensor with the correct argument types and count
+  auto tensor = executorch::extension::from_blob(
+      data, sizes_converted, strides_converted, dtype);
+
+  return tensor;
+}
+
+/**
+ * Helper to clean up tensor data that was allocated with malloc
+ */
+inline void free_tensor_data(Tensor* tensor) {
+  if (tensor && tensor->mutable_data_ptr()) {
+    free(tensor->mutable_data_ptr());
+  }
+}
+
+} // namespace test
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h
new file mode 100644
index 00000000000..82d30cdb4ef
--- /dev/null
+++ b/backends/aoti/utils.h
@@ -0,0 +1,78 @@
+
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <executorch/runtime/platform/log.h>
+#include <cstdint>
+#include <cstddef>
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+
+// Common using declarations for ExecuTorch types
+using executorch::runtime::Error;
+
+extern "C" {
+
+// Common AOTI type aliases
+using AOTITorchError = Error;
+
+// Map int32_t dtype to ExecuTorch ScalarType (robust version of hardcoded
+// ScalarType::Float)
+inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
+  // Convert based on known PyTorch dtype codes (without CUDA-specific
+  // dependency)
+  switch (dtype) {
+    case 6: // PyTorch's float32 dtype code
+      return executorch::aten::ScalarType::Float;
+    // Future support for additional dtypes can be added here
+    default:
+      ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype);
+      return executorch::aten::ScalarType::Undefined;
+  }
+}
+
+// Map int32_t dtype to number of bytes per element (reusing ExecuTorch's
+// elementSize function)
+inline size_t dtype_to_element_size(int32_t dtype) {
+  // First convert int32_t dtype to ExecuTorch ScalarType, then use existing
+  // elementSize function
+  executorch::aten::ScalarType scalar_type = dtype_to_scalar_type(dtype);
+  if (scalar_type == executorch::aten::ScalarType::Undefined) {
+    ET_LOG(Error, "Unsupported dtype: %d for element size calculation", dtype);
+    return 0; // Return 0 to indicate error
+  }
+
+  // Reuse ExecuTorch's existing elementSize function from scalar_type_util.h
+  return executorch::runtime::elementSize(scalar_type);
+}
+
+// Storage offset validation utility function
+inline AOTITorchError validate_storage_offset(int64_t storage_offset) {
+  // Storage offset must always be 0
+  if (storage_offset != 0) {
+    ET_LOG(
+        Error,
+        "Storage offset must be 0. Got storage_offset: %ld",
+        storage_offset);
+    return Error::InvalidArgument;
+  }
+  return Error::Ok;
+}
+
+} // extern "C"
+
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
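
--
Usage sketch for the shim family (illustrative, assuming an ExecuTorch
Tensor* `t`): the size/stride pointers returned by the shims come from the
per-tensor caches in common_shims.cpp and remain valid until
cleanup_tensor_metadata() is called.

    int64_t dim = 0;
    int64_t* sizes = nullptr;
    int64_t* strides = nullptr;
    aoti_torch_get_dim(t, &dim);
    aoti_torch_get_sizes(t, &sizes);     // cached on first call
    aoti_torch_get_strides(t, &strides); // same pointer on repeat calls
    for (int64_t i = 0; i < dim; i++) {
      // sizes[i] and strides[i] describe the row-major layout
    }
    cleanup_tensor_metadata(); // drops the caches; pointers become stale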