Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/aoti/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
# Ensure symbols are exported properly
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

# Link against PyTorch libraries and standard libraries
# Link against ExecuTorch libraries and standard libraries
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
executorch_target_link_options_shared_lib(aoti_common)

Expand Down
9 changes: 8 additions & 1 deletion backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() {
}

// Dtype constants - these return the PyTorch dtype codes
// Currently only float32 is supported, but using robust enum-based approach
// Reports the dtype code used for 32-bit floating point tensors.
int32_t aoti_torch_dtype_float32() {
  constexpr int32_t kFloat32DtypeCode = 6; // PyTorch's float32 dtype code
  return kFloat32DtypeCode;
}

// Reports the dtype code used for bfloat16 tensors.
int32_t aoti_torch_dtype_bfloat16() {
  constexpr int32_t kBFloat16DtypeCode = 15; // PyTorch's bfloat16 dtype code
  return kBFloat16DtypeCode;
}

// Reports the dtype code used for 64-bit signed integer tensors.
int32_t aoti_torch_dtype_int64() {
  constexpr int32_t kInt64DtypeCode = 4; // PyTorch's int64 dtype code
  return kInt64DtypeCode;
}

// Cleanup functions
void cleanup_tensor_metadata() {
internal::tensor_to_sizes.clear();
Expand Down
2 changes: 2 additions & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
int32_t aoti_torch_device_type_cpu();
int32_t aoti_torch_layout_strided();
int32_t aoti_torch_dtype_float32();
int32_t aoti_torch_dtype_bfloat16();
int32_t aoti_torch_dtype_int64();

// Autograd mode functions
int32_t aoti_torch_grad_mode_is_enabled();
Expand Down
5 changes: 1 addition & 4 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic)

# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
target_link_libraries(
aoti_cuda
PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
# Link PyTorch libraries for AOTI CUDA functions
${TORCH_LIBRARIES}
aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
)
# If you need other CUDA libraries, link them similarly:
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)
Expand Down
11 changes: 4 additions & 7 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

#define LOAD_SYMBOL(name, handle) \
do { \
Expand Down Expand Up @@ -335,14 +333,13 @@ class ET_EXPERIMENTAL CudaBackend final
}
};

} // namespace cuda
} // namespace executorch::backends::cuda

namespace executorch::backends {
// File-local objects that register the CUDA backend under the name
// "CudaBackend". Everything here lives in an anonymous namespace, so none of
// these names are visible outside this translation unit.
namespace {
// The backend instance that the registry entry below points at.
auto cls = cuda::CudaBackend();
// Entry pairing the name "CudaBackend" with a pointer to the instance above.
executorch::runtime::Backend backend{"CudaBackend", &cls};
// register_backend() is called while this variable is initialized, and the
// Error it returns is kept here; nothing visible in this block inspects it.
// NOTE(review): the variable name suggests it exists mainly to force the call
// to run at load time - confirm.
static executorch::runtime::Error success_with_compiler =
register_backend(backend);
} // namespace

} // namespace backends
} // namespace executorch
} // namespace executorch::backends
8 changes: 2 additions & 6 deletions backends/cuda/runtime/guard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
#include <executorch/backends/cuda/runtime/guard.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

namespace {
// Thread-local stream storage (private to this file)
Expand Down Expand Up @@ -146,6 +144,4 @@ Result<CUDAStreamGuard> CUDAStreamGuard::create(
return stream_guard;
}

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/guard.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
#include <executorch/runtime/core/result.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

using executorch::runtime::Error;
using executorch::runtime::Result;
Expand Down Expand Up @@ -190,6 +188,4 @@ class CUDAStreamGuard {
DeviceIndex device_index_;
};

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/shims/cuda_guard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

extern "C" {

Expand Down Expand Up @@ -104,6 +102,4 @@ AOTITorchError aoti_torch_get_current_cuda_stream(

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/shims/cuda_guard.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
#include <executorch/backends/cuda/runtime/guard.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

using executorch::backends::aoti::AOTITorchError;

Expand Down Expand Up @@ -99,6 +97,4 @@ AOTITorchError aoti_torch_get_current_cuda_stream(

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
#include <unordered_set>
#include <vector>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

using executorch::aten::SizesType;
using executorch::aten::StridesType;
Expand Down Expand Up @@ -659,6 +657,4 @@ AOTITorchError aoti_torch__reinterpret_tensor(

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/shims/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
#include <executorch/backends/aoti/common_shims.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

using executorch::backends::aoti::AOTITorchError;
using executorch::backends::aoti::Tensor;
Expand Down Expand Up @@ -145,6 +143,4 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
void clear_all_tensors();
} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/shims/tensor_attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

extern "C" {

Expand All @@ -31,6 +29,4 @@ int32_t aoti_torch_device_type_cuda() {

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/shims/tensor_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
#include <executorch/runtime/core/error.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

// Common using declarations for ExecutorTorch types
using executorch::runtime::Error;
Expand All @@ -35,6 +33,4 @@ int32_t aoti_torch_device_type_cuda();

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
8 changes: 2 additions & 6 deletions backends/cuda/runtime/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())

namespace executorch {
namespace backends {
namespace cuda {
namespace executorch::backends::cuda {

// Enum for supported data types in et-cuda backend
enum class SupportedDTypes : int32_t {
Expand Down Expand Up @@ -125,6 +123,4 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
}
} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
} // namespace executorch::backends::cuda
Loading