Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions backends/aoti/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Build AOTI backend for runtime.
#
# ### Editing this file ###
#
# This file should be formatted with
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~
# It should also be cmake-lint clean.
#
# Emit compile_commands.json so clangd/clang-tidy can index this backend.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

# Use ExecuTorch's standard way to find PyTorch libraries for AOTI.
# find_package_torch() is a helper expected to come from Utils.cmake; it
# populates TORCH_INCLUDE_DIRS and TORCH_LIBRARIES used below.
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# Common AOTI functionality - combines all AOTI common components
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
add_library(aoti_common STATIC ${_aoti_common_sources})
target_include_directories(
  aoti_common
  PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}> $<INSTALL_INTERFACE:include>
         # PyTorch AOTI headers from ExecuTorch's torch detection
         ${TORCH_INCLUDE_DIRS}
)
# AOTI-generated code relies on C++ exceptions and RTTI, which ExecuTorch
# builds may disable by default; PUBLIC so consumers compile the same way.
target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
# Ensure symbols are exported properly
# NOTE(review): -Wl,--export-dynamic is specific to GNU-compatible linkers;
# this will not work with MSVC or Apple's ld64 — confirm the set of supported
# platforms for this backend.
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

# Link against PyTorch libraries and standard libraries
target_link_libraries(
  aoti_common
  PUBLIC extension_tensor ${CMAKE_DL_LIBS}
         # Link PyTorch libraries for AOTI functions
         ${TORCH_LIBRARIES}
)
# Project helper (presumably forces whole-archive style linking so backend
# registration symbols are kept) — defined by the ExecuTorch build system.
executorch_target_link_options_shared_lib(aoti_common)

install(
  TARGETS aoti_common
  EXPORT ExecuTorchTargets
  DESTINATION lib
)
28 changes: 28 additions & 0 deletions backends/aoti/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# AOTI Common Library

This directory contains **common library components** for AOTI (Ahead-of-Time Inductor) driven backends in ExecuTorch, **not a standalone backend**.

## Purpose

The code in this directory provides shared functionality and utilities that are used by actual AOTI-driven backends such as:

- **CUDA backend** - Uses AOTI for GPU acceleration
- Other AOTI-powered backends

## Components

- **`common_shims.cpp/h`** - Common shim functions that bridge ExecuTorch tensor operations with AOTI requirements
- **`aoti_model_container.cpp/h`** - Model container functionality for AOTI models
- **`utils.h`** - Utility functions and type definitions
- **`tests/`** - Unit tests for the common functionality

## Usage

This library is intended to be used as a dependency by actual AOTI backend implementations. It is not a backend that can be used directly for model execution.

For example backend implementations that use this common library, see:
- `executorch/backends/cuda/` - CUDA AOTI backend

## Building

The common library components are built as part of the AOTI backend build process. See the `TARGETS` file for build configurations.
3 changes: 3 additions & 0 deletions backends/aoti/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Buck build entry point: the actual target definitions live in targets.bzl so
# they can be shared/reused; this file only loads and invokes them.
load("targets.bzl", "define_common_targets")

define_common_targets()
32 changes: 32 additions & 0 deletions backends/aoti/aoti_model_container.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/aoti/aoti_model_container.h>

namespace executorch {
namespace backends {
namespace aoti {

extern "C" {

// Global function pointers for AOT Inductor model container operations.
// These will be loaded dynamically from the shared library.
//
// All pointers start out as nullptr; whichever backend loads the
// AOTI-compiled shared library is responsible for resolving these symbols
// and assigning them before any call site dereferences them. No null checks
// happen here — callers must verify loading succeeded.
AOTInductorModelContainerCreateWithDeviceFunc
    AOTInductorModelContainerCreateWithDevice = nullptr;
AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete = nullptr;
AOTInductorModelContainerGetNumInputsFunc
    AOTInductorModelContainerGetNumInputs = nullptr;
AOTInductorModelContainerGetNumOutputsFunc
    AOTInductorModelContainerGetNumOutputs = nullptr;
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;

} // extern "C"

} // namespace aoti
} // namespace backends
} // namespace executorch
82 changes: 82 additions & 0 deletions backends/aoti/aoti_model_container.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>

namespace executorch {
namespace backends {
namespace aoti {

using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

extern "C" {

// Type definitions
//
// Error code returned by AOTI container entry points; aliased to
// ExecuTorch's runtime::Error so shim code can pass results through directly.
using AOTIRuntimeError = Error;

// Forward declarations for AOT Inductor model container
struct AOTInductorModelContainerOpaque;
using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
// Opaque stream handle (e.g. a CUDA stream on GPU backends — confirm per
// backend); passed through untouched.
using AOTInductorStreamHandle = void*;
using AOTIProxyExecutorHandle = void*;

// Function pointer types for AOT Inductor model container operations.
// These mirror the C ABI exported by AOTI-compiled shared libraries; the
// actual symbols are resolved at load time (see aoti_model_container.cpp).

// Creates a container holding `num_models` models for the given device
// string. `cubin_dir` names a directory of precompiled kernels (name
// suggests CUDA cubins — confirm for non-CUDA backends).
using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    const char* device_str,
    const char* cubin_dir);

// Destroys a container created by the function above.
using AOTInductorModelContainerDeleteFunc =
    AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle);

// Queries the number of inputs the contained model expects.
using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_inputs);

// Queries the number of outputs the contained model produces.
using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_outputs);

// Executes the model. Ownership notes below are part of the ABI contract.
using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    Tensor** input_handles, // array of input Tensor*; handles
                            // are stolen; the array itself is borrowed
    size_t num_inputs,
    Tensor** output_handles, // array for writing output Tensor*; handles
                             // will be stolen by the caller; the array itself
                             // is borrowed
    size_t n_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle);

// Global function pointers (will be loaded dynamically)
extern AOTInductorModelContainerCreateWithDeviceFunc
    AOTInductorModelContainerCreateWithDevice;
extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
extern AOTInductorModelContainerGetNumInputsFunc
    AOTInductorModelContainerGetNumInputs;
extern AOTInductorModelContainerGetNumOutputsFunc
    AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;

} // extern "C"

// AOTI Delegate Handle structure
//
// Pairs the loaded shared library with the container created from it so a
// delegate instance can be torn down as a unit. so_handle presumably comes
// from dlopen (this header doesn't show the loader) — confirm in the
// backend that populates it.
struct AOTIDelegateHandle {
  void* so_handle;
  AOTInductorModelContainerHandle container_handle;
};

} // namespace aoti
} // namespace backends
} // namespace executorch
145 changes: 145 additions & 0 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace aoti {

namespace internal {
// Global storage for tensor metadata.
//
// AOTI's C interface hands out raw int64_t* views of sizes/strides, so the
// accessor shims below copy each tensor's metadata into stable int64_t
// vectors cached here, keyed by tensor pointer. Entries live until
// cleanup_tensor_metadata() is called.
// NOTE(review): no locking, and a freed-then-reallocated Tensor* would alias
// a stale entry — confirm single-threaded use and cleanup discipline.
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
} // namespace internal

extern "C" {

// Autograd mode functions
// Reports whether autograd gradient tracking is active. ExecuTorch's runtime
// has no autograd at all, so this is permanently off.
int32_t aoti_torch_grad_mode_is_enabled() {
  constexpr int32_t kGradModeOff = 0; // never enabled in ET
  return kGradModeOff;
}

// Rejects any attempt to turn autograd on; ExecuTorch has no autograd, so
// enabling it is unsupported. Disabling (enabled == false) is a no-op.
// NOTE(review): this throws from a function declared inside an extern "C"
// block — if it is ever invoked through a pure C ABI caller, the exception
// cannot safely propagate. Confirm all callers are C++ compiled with
// -fexceptions (the CMake file does pass -fexceptions).
void aoti_torch_grad_mode_set_enabled(bool enabled) {
  if (enabled) {
    throw std::runtime_error("Cannot enable autograd");
  }
}

// Tensor attribute operations
// Exposes the tensor's mutable storage pointer through the AOTI C interface.
AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
  void* data = tensor->mutable_data_ptr();
  *ret_data_ptr = data;
  return Error::Ok;
}

// Reports the tensor's storage offset. ExecuTorch tensors always view their
// storage from the beginning, so this is unconditionally zero.
AOTITorchError aoti_torch_get_storage_offset(
    Tensor* tensor,
    int64_t* ret_storage_offset) {
  constexpr int64_t kZeroOffset = 0;
  *ret_storage_offset = kZeroOffset;
  return Error::Ok;
}

// Returns a pointer to a cached int64_t copy of the tensor's strides,
// building the cache entry on first request. The pointer stays valid until
// cleanup_tensor_metadata() clears the cache.
AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
  auto entry = internal::tensor_to_strides.find(tensor);
  if (entry == internal::tensor_to_strides.end()) {
    // First request for this tensor: copy its strides into an int64_t vector.
    const auto src = tensor->strides();
    std::vector<int64_t> copied(tensor->dim());
    for (int d = 0; d < tensor->dim(); d++) {
      copied[d] = src[d];
    }
    entry =
        internal::tensor_to_strides.emplace(tensor, std::move(copied)).first;
  }

  // A 0-D tensor yields an empty vector, whose data() may be nullptr; hand
  // back a stable placeholder so callers always receive a valid pointer.
  if (entry->second.empty()) {
    static int64_t zero_dim_strides = 0;
    *ret_strides = &zero_dim_strides;
  } else {
    *ret_strides = entry->second.data();
  }

  return Error::Ok;
}

// Reports the tensor's scalar type as the 32-bit dtype code AOTI expects.
AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
  const auto dtype_code = static_cast<int32_t>(tensor->scalar_type());
  *ret_dtype = dtype_code;
  return Error::Ok;
}

// Returns a pointer to a cached int64_t copy of the tensor's sizes, building
// the cache entry on first request. The pointer stays valid until
// cleanup_tensor_metadata() clears the cache.
AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
  auto entry = internal::tensor_to_sizes.find(tensor);
  if (entry == internal::tensor_to_sizes.end()) {
    // First request for this tensor: copy its sizes into an int64_t vector.
    const auto src = tensor->sizes();
    std::vector<int64_t> copied(tensor->dim());
    for (int d = 0; d < tensor->dim(); d++) {
      copied[d] = src[d];
    }
    entry = internal::tensor_to_sizes.emplace(tensor, std::move(copied)).first;
  }

  // A 0-D tensor yields an empty vector, whose data() may be nullptr; hand
  // back a stable placeholder so callers always receive a valid pointer.
  if (entry->second.empty()) {
    static int64_t zero_dim_sizes = 0;
    *ret_sizes = &zero_dim_sizes;
  } else {
    *ret_sizes = entry->second.data();
  }

  return Error::Ok;
}

// Reports the tensor's device index. This shim layer assumes every tensor it
// sees lives on device index 0 (the original comment says CUDA:0).
AOTITorchError aoti_torch_get_device_index(
    Tensor* tensor,
    int32_t* ret_device_index) {
  constexpr int32_t kDefaultDeviceIndex = 0;
  *ret_device_index = kDefaultDeviceIndex;
  return Error::Ok;
}

// Reports the tensor's rank, widened to the 64-bit integer AOTI expects.
AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
  const int64_t rank = tensor->dim();
  *ret_dim = rank;
  return Error::Ok;
}

// Device and layout utility functions
// Device-type code for CPU. ExecuTorch mirrors the convention where CPU == 0.
int32_t aoti_torch_device_type_cpu() {
  constexpr int32_t kCpuDeviceType = 0;
  return kCpuDeviceType;
}

// Layout code for strided tensors. ExecuTorch only supports strided layout,
// so this always returns 0 (a.k.a. at::Layout::Strided).
int32_t aoti_torch_layout_strided() {
  constexpr int32_t kStridedLayout = 0;
  return kStridedLayout;
}

// Dtype constants - these return the PyTorch dtype codes
// Currently only float32 is supported, but using robust enum-based approach
// Dtype code for float32: 6 is the value PyTorch's ScalarType assigns to it.
int32_t aoti_torch_dtype_float32() {
  constexpr int32_t kFloat32DtypeCode = 6;
  return kFloat32DtypeCode;
}

// Cleanup functions
// Discards every cached size/stride vector. Pointers previously returned by
// aoti_torch_get_sizes / aoti_torch_get_strides become dangling after this.
void cleanup_tensor_metadata() {
  internal::tensor_to_strides.clear();
  internal::tensor_to_sizes.clear();
}

} // extern "C"

} // namespace aoti
} // namespace backends
} // namespace executorch
Loading
Loading