Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backends/aoti/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,23 @@ inline AOTITorchError validate_storage_offset(int64_t storage_offset) {
return Error::Ok;
}

// Check if tensor is in contiguous memory format (NCHW for 4D tensors).
// Contiguous format means strides decrease from left to right:
// For NCHW: strides = [C*H*W, H*W, W, 1]
//
// Dimensions of size 1 are skipped: their stride cannot affect the memory
// layout, so any stride value is acceptable there. This matches PyTorch's
// is_contiguous() semantics (see TensorImpl::compute_contiguous), which the
// original strict stride comparison would incorrectly reject.
inline bool is_tensor_contiguous(
    int64_t ndim,
    const int64_t* sizes,
    const int64_t* strides) {
  int64_t expected_stride = 1;
  for (int64_t i = ndim - 1; i >= 0; i--) {
    // A size-1 dimension is contiguous regardless of its reported stride.
    if (sizes[i] == 1) {
      continue;
    }
    if (strides[i] != expected_stride) {
      return false;
    }
    expected_stride *= sizes[i];
  }
  return true;
}

} // extern "C"

} // namespace aoti
Expand Down
187 changes: 151 additions & 36 deletions backends/cuda/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,10 @@
#include <cstdint>
#include <cstdlib> // For posix_memalign
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// CUDA error checking macro
#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
do { \
const cudaError_t err = EXPR; \
if (err == cudaSuccess) { \
break; \
} \
ET_LOG( \
Error, \
"%s:%d CUDA error: %s", \
__FILE__, \
__LINE__, \
cudaGetErrorString(err)); \
return Error::Internal; \
} while (0)

// Kernel launch check macro
#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())

namespace executorch {
namespace backends {
namespace cuda {
Expand All @@ -46,12 +27,122 @@ using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
using executorch::backends::aoti::validate_storage_offset;

// Global storage for tensors and their metadata
std::unordered_set<std::shared_ptr<Tensor>> tensors;

// Reference counting for memory addresses
// Maps memory address to number of tensors using it
// Special value: NOT_OWN (-1) means tensor never owns the memory
constexpr int32_t NOT_OWN = -1;
std::unordered_map<void*, int32_t> memory_to_n_tensor;

extern "C" {

// Wrap an existing memory blob in an ExecuTorch tensor WITHOUT copying it.
// The tensor never owns the blob: the address is registered as NOT_OWN, so
// aoti_torch_delete_tensor_object will not free it — the caller retains
// ownership of the memory.
//
// Fix vs. previous version: the duplicate-memory-address check now runs
// BEFORE the tensor is created, inserted into `tensors`, and published via
// *ret_new_tensor. Previously a failing check returned an error while leaving
// a half-registered tensor in the global set and a stale pointer in the
// caller's out-parameter.
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    Tensor** ret_new_tensor,
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size) {
  // TODO(gasoonjia): verify given data is on the target device
  (void)device_type;
  (void)opaque_metadata;
  (void)layout;
  (void)opaque_metadata_size;

  // ---- Validate all inputs before creating or registering any state ----
  if (data == nullptr) {
    ET_LOG(
        Error,
        "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null");
    return Error::InvalidArgument;
  }

  if (sizes_ptr == nullptr && ndim > 0) {
    ET_LOG(
        Error,
        "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null");
    return Error::InvalidArgument;
  }

  if (ret_new_tensor == nullptr) {
    ET_LOG(
        Error,
        "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null");
    return Error::InvalidArgument;
  }

  // Only a single device (index 0) is supported for now.
  if (device_index != 0) {
    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
    return Error::InvalidArgument;
  }

  // Validate dtype using SupportedDTypes from utils.h
  AOTITorchError dtype_error = validate_dtype(dtype);
  if (dtype_error != Error::Ok) {
    return dtype_error;
  }

  // Storage offset must be 0 since from_blob cannot handle different offsets
  AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
  if (storage_offset_error != Error::Ok) {
    return storage_offset_error;
  }

  // Reject an already-tracked memory address up front, while no tensor has
  // been created yet, so the error path has no side effects to unwind.
  if (memory_to_n_tensor.find(data) != memory_to_n_tensor.end()) {
    ET_LOG(
        Error,
        "Memory address %p is already being tracked by another tensor",
        data);
    return Error::InvalidArgument;
  }

  // Convert sizes to the format expected by ExecutorTorch using SizesType
  std::vector<executorch::aten::SizesType> sizes =
      convert_sizes_to_vector(ndim, sizes_ptr);

  // Convert strides using the common helper function with StridesType
  std::vector<executorch::aten::StridesType> strides =
      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // Create ExecutorTorch tensor that wraps the existing memory.
  // Note: we're NOT copying the data, just wrapping it.
  auto tensor = executorch::extension::from_blob(
      data, // existing memory (don't copy!)
      sizes, // tensor dimensions
      strides, // tensor strides (allows different strides)
      dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
  );

  if (!tensor) {
    ET_LOG(Error, "Failed to create tensor from blob");
    return Error::InvalidArgument;
  }

  // Keep the tensor alive until it is explicitly deleted.
  tensors.insert(tensor);
  *ret_new_tensor = tensor.get();

  // Mark this memory as NOT_OWN since a tensor created from a blob never
  // owns the underlying memory.
  memory_to_n_tensor[data] = NOT_OWN;

  return Error::Ok;
}

AOTITorchError aoti_torch_empty_strided(
int64_t ndim,
const int64_t* sizes_ptr,
Expand Down Expand Up @@ -120,6 +211,9 @@ AOTITorchError aoti_torch_empty_strided(
tensors.insert(tensor);
*ret_new_tensor = tensor.get();

// This tensor owns the memory it allocated, set reference count to 1
memory_to_n_tensor[ptr] = 1;

return Error::Ok;
}

Expand Down Expand Up @@ -164,26 +258,47 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
if (it->get() == tensor) {
// Get the tensor before erasing
auto tensor_ptr = *it;

void* data_ptr = tensor_ptr->mutable_data_ptr();

// Determine if it's GPU memory
cudaPointerAttributes attributes{};
ET_CUDA_CHECK_OR_RETURN_ERROR(
cudaPointerGetAttributes(&attributes, data_ptr));

// et tensor does not own data; need to free them manually.
if (attributes.type == cudaMemoryTypeManaged) {
// This is CUDA managed memory - free with proper synchronization
ET_CUDA_CHECK_OR_RETURN_ERROR(
cudaDeviceSynchronize()); // Wait for all operations to complete
// BEFORE freeing
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
// Find the reference count for this memory address
auto memory_it = memory_to_n_tensor.find(data_ptr);
if (memory_it != memory_to_n_tensor.end()) {
int32_t ref_count = memory_it->second;

if (ref_count == NOT_OWN) {
// Tensor never owned the memory, skip freeing
// Just remove tensor from tracking
tensors.erase(it);
return Error::Ok;
} else if (ref_count == 1) {
// Only current tensor using this memory, free it
// Determine if it's GPU memory
cudaPointerAttributes attributes{};
ET_CUDA_CHECK_OR_RETURN_ERROR(
cudaPointerGetAttributes(&attributes, data_ptr));

if (attributes.type == cudaMemoryTypeManaged) {
// This is CUDA managed memory - free with proper synchronization
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize());
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
} else {
// This is CPU memory - free immediately
free(data_ptr);
data_ptr = nullptr;
}

// Remove from memory tracking
memory_to_n_tensor.erase(memory_it);
} else if (ref_count > 1) {
// Other tensors still using this memory, just decrement count
memory_to_n_tensor[data_ptr] = ref_count - 1;
}
} else {
// This is CPU memory - free immediately
free(data_ptr);
ET_LOG(Error, "Internal error: memory not found during deletion");
return Error::Internal;
}
// Remove from set (this will call the destructor if it's the last

// Remove tensor from set (this will call the destructor if it's the last
// reference)
tensors.erase(it);
return Error::Ok;
Expand Down
39 changes: 38 additions & 1 deletion backends/cuda/runtime/shims/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,44 @@ using executorch::backends::aoti::Tensor;

extern "C" {

/**
* Creates a tensor object from an existing memory blob without copying the
* data. The tensor will wrap the provided memory and will not take ownership of
* it. When the tensor is deleted, the original memory will remain valid and
* must be freed by the caller.
*
* @param data Pointer to the memory blob to wrap (must not be null)
* @param ndim Number of dimensions in the tensor
* @param sizes_ptr Pointer to array of dimension sizes (using SizesType)
* @param strides_ptr Pointer to array of strides for each dimension (using
* StridesType, can be null for contiguous)
* @param storage_offset Storage offset (must be 0 for current implementation)
* @param dtype Data type identifier (supports FLOAT32 and BFLOAT16 from
* SupportedDTypes)
* @param device_type Device type (CPU=0, CUDA=1 from SupportedDevices)
* @param device_index Device index (must be 0 for current implementation)
* @param ret_new_tensor Output parameter for the created tensor (must not be
* null)
* @param layout Tensor layout identifier (0=strided)
* @param opaque_metadata Optional metadata pointer (can be null)
* @param opaque_metadata_size Size of opaque metadata in bytes
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor,
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);

/**
* Creates an uninitialized tensor with specified dimensions, strides, and
dtype on either CPU or CUDA device.
Expand Down Expand Up @@ -55,7 +93,6 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);

// Function to clear all tensors from internal storage
void clear_all_tensors();

} // extern "C"

} // namespace cuda
Expand Down
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ def define_common_targets():
"""
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
Loading
Loading