9 changes: 2 additions & 7 deletions backends/aoti/CMakeLists.txt
@@ -40,13 +40,8 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
# Ensure symbols are exported properly
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

-# Link against PyTorch libraries and standard libraries
-target_link_libraries(
-aoti_common
-PUBLIC extension_tensor ${CMAKE_DL_LIBS}
-# Link PyTorch libraries for AOTI functions
-${TORCH_LIBRARIES}
-)
+# Link against ExecuTorch libraries and standard libraries
+target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
executorch_target_link_options_shared_lib(aoti_common)

install(
2 changes: 2 additions & 0 deletions backends/aoti/aoti_model_container.h
@@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
void* so_handle;
std::string so_path;
AOTInductorModelContainerHandle container_handle;
+void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
+                   // dependency
};

} // namespace aoti
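For context on the new cuda_stream member: keeping it as a void* means this shared header needs no CUDA includes, and the CUDA backend casts the value back to cudaStream_t in translation units that already include cuda_runtime.h. A minimal sketch of that round trip, assuming creation and teardown happen on the backend side (the helper names are illustrative, not the backend's actual functions):

#include <cuda_runtime.h>

// Create a stream and hand it back type-erased, ready to be stored in
// AOTIDelegateHandle::cuda_stream.
void* make_type_erased_stream() {
  cudaStream_t stream = nullptr;
  cudaStreamCreate(&stream);           // real CUDA handle
  return static_cast<void*>(stream);   // stored as void* in the handle
}

// Recover the typed stream only inside CUDA-aware code, then release it.
void sync_and_destroy_stream(void* erased) {
  cudaStream_t stream = static_cast<cudaStream_t>(erased);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
}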
9 changes: 8 additions & 1 deletion backends/aoti/common_shims.cpp
@@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() {
}

// Dtype constants - these return the PyTorch dtype codes
-// Currently only float32 is supported, but using robust enum-based approach
int32_t aoti_torch_dtype_float32() {
return 6; // PyTorch's float32 dtype code
}

+int32_t aoti_torch_dtype_bfloat16() {
+return 15; // PyTorch's bfloat16 dtype code
+}
+
+int32_t aoti_torch_dtype_int64() {
+return 4; // PyTorch's int64 dtype code
+}
+
// Cleanup functions
void cleanup_tensor_metadata() {
internal::tensor_to_sizes.clear();
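The dtype shims above return PyTorch's raw dtype codes (float32 = 6, bfloat16 = 15, int64 = 4). As a sketch of how a caller might gate on the supported set, using only the declarations added in common_shims.h (the helper name is illustrative, not part of this PR):

#include <cstdint>

// Declarations from backends/aoti/common_shims.h.
int32_t aoti_torch_dtype_float32();
int32_t aoti_torch_dtype_bfloat16();
int32_t aoti_torch_dtype_int64();

// Hypothetical helper: accept only dtype codes these shims currently expose.
bool is_supported_aoti_dtype(int32_t dtype_code) {
  return dtype_code == aoti_torch_dtype_float32() ||   // 6
         dtype_code == aoti_torch_dtype_bfloat16() ||  // 15
         dtype_code == aoti_torch_dtype_int64();       // 4
}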
2 changes: 2 additions & 0 deletions backends/aoti/common_shims.h
@@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
int32_t aoti_torch_device_type_cpu();
int32_t aoti_torch_layout_strided();
int32_t aoti_torch_dtype_float32();
+int32_t aoti_torch_dtype_bfloat16();
+int32_t aoti_torch_dtype_int64();

// Autograd mode functions
int32_t aoti_torch_grad_mode_is_enabled();
2 changes: 1 addition & 1 deletion backends/aoti/targets.bzl
@@ -51,7 +51,7 @@ def define_common_targets():
link_whole = True,
supports_python_dlopen = True,
visibility = ["@EXECUTORCH_CLIENTS"],
-deps = [
+exported_deps = [
":common_shims",
":model_container",
],
11 changes: 5 additions & 6 deletions backends/cuda/CMakeLists.txt
@@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# CUDA-specific AOTI functionality
-set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
-runtime/shims/tensor_attribute.cpp runtime/guard.cpp
+set(_aoti_cuda_sources
+runtime/cuda_backend.cpp runtime/shims/memory.cpp
+runtime/shims/tensor_attribute.cpp runtime/guard.cpp
+runtime/shims/cuda_guard.cpp
)
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
target_include_directories(
@@ -53,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic)

# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
target_link_libraries(
-aoti_cuda
-PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
-# Link PyTorch libraries for AOTI CUDA functions
-${TORCH_LIBRARIES}
+aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
)
# If you need other CUDA libraries, link them similarly:
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)
24 changes: 24 additions & 0 deletions backends/cuda/runtime/TARGETS
@@ -6,11 +6,13 @@ runtime.cxx_library(
name = "runtime_shims",
srcs = [
"guard.cpp",
"shims/cuda_guard.cpp",
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"guard.h",
"shims/cuda_guard.h",
"shims/memory.h",
"shims/tensor_attribute.h",
"utils.h",
@@ -32,3 +34,25 @@
("cuda", None, "cuda-lazy"),
],
)

+runtime.cxx_library(
+name = "cuda_backend",
+srcs = [
+"cuda_backend.cpp",
+],
+# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+link_whole = True,
+supports_python_dlopen = True,
+# Constructor needed for backend registration.
+compiler_flags = ["-Wno-global-constructors"],
+visibility = ["@EXECUTORCH_CLIENTS"],
+deps = [
+":runtime_shims",
+"//executorch/backends/aoti:aoti_common",
+"//executorch/runtime/backend:interface",
+"//executorch/runtime/core/exec_aten/util:tensor_util",
+],
+external_deps = [
+("cuda", None, "cuda-lazy"),
+],
+)