2 changes: 2 additions & 0 deletions backends/aoti/utils.h
@@ -36,6 +36,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
switch (dtype) {
case 6: // PyTorch's float32 dtype code
return executorch::aten::ScalarType::Float;
case 15: // PyTorch's bfloat16 dtype code
return executorch::aten::ScalarType::BFloat16;
// Future support for additional dtypes can be added here
default:
ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype);
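For reference, a minimal sketch of how a caller could exercise the new mapping; this assumes the header is included the same way the CUDA shim below includes it, and that dtype code 15 is PyTorch's bfloat16 code as the diff states:

```cpp
#include <executorch/backends/aoti/utils.h>

// Sketch only: dtype_to_scalar_type maps PyTorch dtype codes to ExecuTorch
// scalar types. Code 6 maps to Float and, with this change, code 15 maps to
// BFloat16; any other code hits the default branch, which logs an error.
void dtype_mapping_example() {
  using executorch::backends::aoti::dtype_to_scalar_type;
  auto float_type = dtype_to_scalar_type(6);  // ScalarType::Float
  auto bf16_type = dtype_to_scalar_type(15);  // ScalarType::BFloat16
  (void)float_type;
  (void)bf16_type;
}
```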
32 changes: 32 additions & 0 deletions backends/cuda/runtime/TARGETS
@@ -0,0 +1,32 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.cxx_library(
name = "runtime_shims",
srcs = [
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"shims/memory.h",
"shims/tensor_attribute.h",
"shims/utils.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
supports_python_dlopen = True,
# Constructor needed for backend registration.
compiler_flags = ["-Wno-global-constructors"],
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [
"//executorch/backends/aoti:common_shims",
"//executorch/extension/tensor:tensor",
"//executorch/runtime/core:core",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/platform:platform",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
)
135 changes: 135 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
@@ -0,0 +1,135 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
#include <executorch/backends/cuda/runtime/shims/utils.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>
#include <cstdlib> // For posix_memalign
#include <memory>
#include <unordered_set>
#include <vector>

// CUDA error checking macro
#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
do { \
const cudaError_t err = EXPR; \
if (err == cudaSuccess) { \
break; \
} \
ET_LOG( \
Error, \
"%s:%d CUDA error: %s", \
__FILE__, \
__LINE__, \
cudaGetErrorString(err)); \
return Error::Internal; \
} while (0)

// Kernel launch check macro
#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())

namespace executorch {
namespace backends {
namespace cuda {

using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;

// Global storage for tensors and their metadata
std::unordered_set<std::shared_ptr<Tensor>> tensors;

extern "C" {

AOTITorchError aoti_torch_empty_strided(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor) {
// Check that device_index is always 0
if (device_index != 0) {
ET_LOG(Error, "device_index must be 0, got: %d", device_index);
return Error::InvalidArgument;
}

// Reserve memory on the requested device (CUDA or CPU) and wrap it in an ETensor
void* ptr;
int64_t numel = 1;
for (int64_t i = 0; i < ndim; i++) {
numel *= sizes_ptr[i];
}

AOTITorchError dtype_error = validate_dtype(dtype);
if (dtype_error != Error::Ok) {
return dtype_error;
}

size_t element_size = dtype_to_element_size(dtype);
if (element_size == 0) {
ET_LOG(Error, "Invalid element size for dtype: %d", dtype);
return Error::InvalidArgument;
}
int64_t nbytes = numel * element_size;

if (device_type == 1) { // cuda
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
} else if (device_type == 0) { // cpu
// Ensure 16-byte alignment for CPU memory to match CUDA requirements
int result = posix_memalign(&ptr, 16, nbytes);
if (result != 0) {
ET_LOG(Error, "Failed to allocate aligned CPU memory");
return Error::MemoryAllocationFailed;
}
if (ptr == nullptr) {
ET_LOG(Error, "Failed to call posix_memalign");
return Error::MemoryAllocationFailed;
}
} else {
ET_LOG(
Error,
"Need to implement empty_strided for non-CUDA non-CPU device type %d",
device_type);
return Error::NotImplemented;
}

// ETensor sizes
auto sizes = convert_sizes_to_vector(ndim, sizes_ptr);

// ETensor strides
auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

// ETensor creation with dynamic shape support for edge cases
auto tensor = executorch::extension::from_blob(
ptr, sizes, strides, dtype_to_scalar_type(dtype));

// Store the tensor so it doesn't get destroyed
tensors.insert(tensor);
*ret_new_tensor = tensor.get();

return Error::Ok;
}

// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
void clear_all_tensors() {
tensors.clear();
}

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
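A hedged usage sketch of the new shim, assuming AOTITorchError aliases executorch::runtime::Error (as the return values above suggest) and that the caller passes contiguous strides; names outside the diff are illustrative:

```cpp
#include <executorch/backends/cuda/runtime/shims/memory.h>

using executorch::backends::aoti::AOTITorchError;
using executorch::backends::aoti::Tensor;
using executorch::backends::cuda::aoti_torch_empty_strided;
using executorch::backends::cuda::clear_all_tensors;
using executorch::runtime::Error;

// Allocate a contiguous 2x3 float32 tensor in CUDA managed memory:
// dtype code 6 = float32, device_type 1 = CUDA, device_index must be 0.
bool empty_strided_example() {
  const int64_t sizes[] = {2, 3};
  const int64_t strides[] = {3, 1};  // row-major contiguous strides
  Tensor* tensor = nullptr;
  AOTITorchError err = aoti_torch_empty_strided(
      /*ndim=*/2, sizes, strides, /*dtype=*/6,
      /*device_type=*/1, /*device_index=*/0, &tensor);
  if (err != Error::Ok) {
    return false;
  }
  // ... use *tensor through the ExecuTorch tensor API ...
  clear_all_tensors();  // drop the shim's references to every tensor it created
  return true;
}
```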
55 changes: 55 additions & 0 deletions backends/cuda/runtime/shims/memory.h
@@ -0,0 +1,55 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {

using executorch::backends::aoti::AOTITorchError;
using executorch::backends::aoti::Tensor;

extern "C" {

/**
* Creates an uninitialized tensor with specified dimensions, strides, and
 * dtype on either a CPU or CUDA device.
*
* @param ndim Number of dimensions in the tensor
* @param sizes_ptr Pointer to array of dimension sizes
* @param strides_ptr Pointer to array of strides for each dimension
* @param dtype Data type identifier (matches PyTorch scalar types)
* @param device_type Device type (0=CPU, 1=CUDA)
* @param device_index Device index (must be 0 for current implementation)
* @param ret_new_tensor Output parameter for the created tensor
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_empty_strided(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor);

// Function to clear all tensors from internal storage
// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
void clear_all_tensors();

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
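The header leaves stride layout to the caller. A small illustrative helper (not part of the shim API) for the common row-major contiguous case, where the last dimension has stride 1 and each earlier stride is the product of the trailing sizes:

```cpp
#include <cstdint>
#include <vector>

// Illustrative helper only: compute row-major contiguous strides for a
// given shape, suitable for the strides_ptr argument above.
std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return strides;
}
// Example: sizes {2, 3, 4} -> strides {12, 4, 1}.
```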
6 changes: 6 additions & 0 deletions backends/cuda/runtime/shims/tests/TARGETS
@@ -0,0 +1,6 @@
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
30 changes: 30 additions & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
@@ -0,0 +1,30 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")

def cuda_shim_cpp_unittest(name):
cpp_unittest(
name = "test_" + name,
srcs = [
"test_" + name + ".cpp",
],
deps = [
"//executorch/backends/aoti:common_shims",
"//executorch/backends/cuda/runtime:runtime_shims",
"//executorch/extension/tensor:tensor",
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
"//executorch/runtime/core/exec_aten:lib",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
)

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
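The test source referenced by this target is not shown in this view; below is a hypothetical gtest-style sketch of the kind of check it could contain, assuming cpp_unittest wires up GoogleTest and that AOTITorchError compares against executorch::runtime::Error values:

```cpp
// Hypothetical sketch only; the PR's actual test_aoti_torch_empty_strided.cpp
// is not shown in this diff view.
#include <gtest/gtest.h>

#include <executorch/backends/cuda/runtime/shims/memory.h>

using executorch::backends::aoti::Tensor;
using executorch::backends::cuda::aoti_torch_empty_strided;
using executorch::backends::cuda::clear_all_tensors;
using executorch::runtime::Error;

TEST(AotiTorchEmptyStridedTest, RejectsNonZeroDeviceIndex) {
  const int64_t sizes[] = {2, 2};
  const int64_t strides[] = {2, 1};
  Tensor* tensor = nullptr;
  // device_index must be 0 per the shim's documented contract.
  auto err = aoti_torch_empty_strided(
      /*ndim=*/2, sizes, strides, /*dtype=*/6,
      /*device_type=*/1, /*device_index=*/1, &tensor);
  EXPECT_EQ(err, Error::InvalidArgument);
  clear_all_tensors();
}
```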