19 changes: 19 additions & 0 deletions backends/aoti/utils.h
@@ -36,6 +36,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
switch (dtype) {
case 6: // PyTorch's float32 dtype code
return executorch::aten::ScalarType::Float;
case 15: // PyTorch's bfloat16 dtype code
return executorch::aten::ScalarType::BFloat16;
// Future support for additional dtypes can be added here
default:
ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype);
@@ -71,6 +73,23 @@ inline AOTITorchError validate_storage_offset(int64_t storage_offset) {
return Error::Ok;
}

// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
// Contiguous format means strides decrease from left to right:
// For NCHW: strides = [C*H*W, H*W, W, 1]
inline bool is_tensor_contiguous(
int64_t ndim,
const int64_t* sizes,
const int64_t* strides) {
int64_t expected_stride = 1;
for (int64_t i = ndim - 1; i >= 0; i--) {
if (strides[i] != expected_stride) {
return false;
}
expected_stride *= sizes[i];
}
return true;
}

} // extern "C"

} // namespace aoti
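To make the stride rule above concrete, here is a minimal standalone sketch (not part of the diff) that exercises is_tensor_contiguous with hypothetical shapes; the include path and namespace follow the usage in memory.cpp below.

#include <cassert>
#include <cstdint>

#include <executorch/backends/aoti/utils.h>

int main() {
  using executorch::backends::aoti::is_tensor_contiguous;

  // NCHW tensor of shape [2, 3, 4, 5]: contiguous strides are
  // [C*H*W, H*W, W, 1] = [60, 20, 5, 1].
  const int64_t sizes[] = {2, 3, 4, 5};
  const int64_t contiguous_strides[] = {60, 20, 5, 1};
  // The same shape stored channels-last (NHWC): strides [60, 1, 15, 3].
  const int64_t channels_last_strides[] = {60, 1, 15, 3};

  assert(is_tensor_contiguous(4, sizes, contiguous_strides));     // matches the expected pattern
  assert(!is_tensor_contiguous(4, sizes, channels_last_strides)); // innermost stride is 3, not 1
  return 0;
}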
32 changes: 32 additions & 0 deletions backends/cuda/runtime/TARGETS
@@ -0,0 +1,32 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.cxx_library(
name = "runtime_shims",
srcs = [
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"shims/memory.h",
"shims/tensor_attribute.h",
"shims/utils.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
supports_python_dlopen = True,
# Constructor needed for backend registration.
compiler_flags = ["-Wno-global-constructors"],
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [
"//executorch/backends/aoti:common_shims",
"//executorch/extension/tensor:tensor",
"//executorch/runtime/core:core",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/platform:platform",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
)
317 changes: 317 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
@@ -0,0 +1,317 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
#include <executorch/backends/cuda/runtime/shims/utils.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>
#include <cstdlib> // For posix_memalign
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace executorch {
namespace backends {
namespace cuda {

using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
using executorch::backends::aoti::validate_storage_offset;

// Global storage for tensors and their metadata
std::unordered_set<std::shared_ptr<Tensor>> tensors;

// Reference counting for memory addresses
// Maps memory address to number of tensors using it
// Special value: NOT_OWN (-1) means tensor never owns the memory
constexpr int32_t NOT_OWN = -1;
std::unordered_map<void*, int32_t> memory_to_n_tensor;
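// Ownership states tracked in memory_to_n_tensor (summary of the logic below):
//   NOT_OWN (-1): memory belongs to the caller (from_blob); deletion never frees it.
//   1: exactly one tensor owns the allocation; deletion frees it (cudaFree or free).
//   >1: shared by several tensors; deletion only decrements the count.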

extern "C" {

AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor,
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
// TODO(gasoonjia): verify given data is on the target device
(void)device_type;
(void)opaque_metadata;
(void)layout;
(void)opaque_metadata_size;

// Validate input parameters first
if (data == nullptr) {
ET_LOG(
Error,
"aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null");
return Error::InvalidArgument;
}

if (sizes_ptr == nullptr && ndim > 0) {
ET_LOG(
Error,
"aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null");
return Error::InvalidArgument;
}

if (ret_new_tensor == nullptr) {
ET_LOG(
Error,
"aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null");
return Error::InvalidArgument;
}

// Check that device_index is always 0
if (device_index != 0) {
ET_LOG(Error, "device_index must be 0, got: %d", device_index);
return Error::InvalidArgument;
}

// Validate dtype using SupportedDTypes from utils.h
AOTITorchError dtype_error = validate_dtype(dtype);
if (dtype_error != Error::Ok) {
return dtype_error;
}

// Storage offset must be 0 since from_blob cannot handle different offsets
AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
if (storage_offset_error != Error::Ok) {
return storage_offset_error;
}

// Convert sizes to the format expected by ExecuTorch using SizesType
std::vector<executorch::aten::SizesType> sizes =
convert_sizes_to_vector(ndim, sizes_ptr);

// Convert strides using the common helper function with StridesType
std::vector<executorch::aten::StridesType> strides =
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

// Create ExecuTorch tensor that wraps the existing memory
// Note: We're NOT copying the data, just wrapping it
auto tensor = executorch::extension::from_blob(
data, // existing memory (don't copy!)
sizes, // tensor dimensions
strides, // tensor strides (allows different strides)
dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
);

if (!tensor) {
ET_LOG(Error, "Failed to create tensor from blob");
return Error::InvalidArgument;
}

// Store the tensor so it doesn't get destroyed
tensors.insert(tensor);

*ret_new_tensor = tensor.get();

// Check if this memory address is already being tracked
auto memory_it = memory_to_n_tensor.find(data);
if (memory_it != memory_to_n_tensor.end()) {
ET_LOG(
Error,
"Memory address %p is already being tracked by another tensor",
data);
return Error::InvalidArgument;
}

// Mark this memory as NOT_OWN since tensor created from blob never owns
// memory
memory_to_n_tensor[data] = NOT_OWN;

return Error::Ok;
}

AOTITorchError aoti_torch_empty_strided(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor) {
// Check that device_index is always 0
if (device_index != 0) {
ET_LOG(Error, "device_index must be 0, got: %d", device_index);
return Error::InvalidArgument;
}

// This requires us to reserve CUDA memory and put it into an ETensor
void* ptr;
int64_t numel = 1;
for (int64_t i = 0; i < ndim; i++) {
numel *= sizes_ptr[i];
}

AOTITorchError dtype_error = validate_dtype(dtype);
if (dtype_error != Error::Ok) {
return dtype_error;
}

size_t element_size = dtype_to_element_size(dtype);
if (element_size == 0) {
ET_LOG(Error, "Invalid element size for dtype: %d", dtype);
return Error::InvalidArgument;
}
int64_t nbytes = numel * element_size;

if (device_type == 1) { // cuda
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
} else if (device_type == 0) { // cpu
// Ensure 16-byte alignment for CPU memory to match CUDA requirements
int result = posix_memalign(&ptr, 16, nbytes);
if (result != 0) {
ET_LOG(Error, "Failed to allocate aligned CPU memory");
return Error::MemoryAllocationFailed;
}
if (ptr == nullptr) {
ET_LOG(Error, "Failed to call posix_memalign");
return Error::MemoryAllocationFailed;
}
} else {
ET_LOG(
Error,
"Need to implement empty_strided for non-CUDA non-CPU device type %d",
device_type);
return Error::NotImplemented;
}

// ETensor sizes
auto sizes = convert_sizes_to_vector(ndim, sizes_ptr);

// ETensor strides
auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

// ETensor creation with dynamic shape support for edge cases
auto tensor = executorch::extension::from_blob(
ptr, sizes, strides, dtype_to_scalar_type(dtype));

// Store the tensor so it doesn't get destroyed
tensors.insert(tensor);
*ret_new_tensor = tensor.get();

// This tensor owns the memory it allocated, set reference count to 1
memory_to_n_tensor[ptr] = 1;

return Error::Ok;
}

void clear_all_tensors() {
// Use aoti_torch_delete_tensor_object to properly delete each tensor.
// Note: move the set aside first, since deletion modifies `tensors`.
auto old_tensors = std::move(tensors); // leaves `tensors` empty; no copy needed
for (const auto& tensor_shared : old_tensors) {
aoti_torch_delete_tensor_object(tensor_shared.get());
}

// tensors set should now be empty, but ensure it's cleared
tensors.clear();
}

AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
// Handle null tensor pointer
if (tensor == nullptr) {
ET_LOG(Error, "Cannot delete null tensor");
return Error::InvalidArgument;
}

// Check if tensor exists in our tracking
bool found_in_tensors = false;
for (auto it = tensors.begin(); it != tensors.end(); ++it) {
if (it->get() == tensor) {
found_in_tensors = true;
break;
}
}

// If tensor not found in our tracking, it's invalid
if (!found_in_tensors) {
ET_LOG(Error, "Didn't find tensor %p", tensor);
return Error::InvalidArgument;
}

// Find and delete the tensor
for (auto it = tensors.begin(); it != tensors.end(); ++it) {
if (it->get() == tensor) {
// Get the tensor before erasing
auto tensor_ptr = *it;
void* data_ptr = tensor_ptr->mutable_data_ptr();

// Find the reference count for this memory address
auto memory_it = memory_to_n_tensor.find(data_ptr);
if (memory_it != memory_to_n_tensor.end()) {
int32_t ref_count = memory_it->second;

if (ref_count == NOT_OWN) {
// Tensor never owned the memory, skip freeing
// Just remove tensor from tracking
tensors.erase(it);
return Error::Ok;
} else if (ref_count == 1) {
// Only current tensor using this memory, free it
// Determine if it's GPU memory
cudaPointerAttributes attributes{};
ET_CUDA_CHECK_OR_RETURN_ERROR(
cudaPointerGetAttributes(&attributes, data_ptr));

if (attributes.type == cudaMemoryTypeManaged) {
// This is CUDA managed memory - free with proper synchronization
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize());
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
} else {
// This is CPU memory - free immediately
free(data_ptr);
data_ptr = nullptr;
}

// Remove from memory tracking
memory_to_n_tensor.erase(memory_it);
} else if (ref_count > 1) {
// Other tensors still using this memory, just decrement count
memory_to_n_tensor[data_ptr] = ref_count - 1;
}
} else {
ET_LOG(Error, "Internal error: memory not found during deletion");
return Error::Internal;
}

// Remove tensor from set (this will call the destructor if it's the last
// reference)
tensors.erase(it);
return Error::Ok;
}
}

// This should never be reached since we found it above
ET_LOG(Error, "Internal error: tensor not found after validation");
return Error::Internal;
}

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
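A minimal usage sketch of the ownership protocol implemented above (illustrative only, not part of the PR): aoti_torch_empty_strided allocates memory and sets its reference count to 1, aoti_torch_create_tensor_from_blob_v2 records caller-owned memory as NOT_OWN, and aoti_torch_delete_tensor_object frees only memory the tensor owns. The dtype and device codes follow the values used in the diff (6 = float32, device_type 1 = CUDA, device_index 0); the include path is assumed.

#include <cstdint>

#include <executorch/backends/cuda/runtime/shims/memory.h>

using namespace executorch::backends::cuda;

void example_lifecycle() {
  const int64_t sizes[] = {2, 3};
  const int64_t strides[] = {3, 1}; // contiguous strides for a [2, 3] tensor

  // Backend-owned allocation: the new memory's reference count starts at 1.
  Tensor* owned = nullptr;
  AOTITorchError err = aoti_torch_empty_strided(
      /*ndim=*/2, sizes, strides,
      /*dtype=*/6, /*device_type=*/1, /*device_index=*/0, &owned);
  if (err != executorch::runtime::Error::Ok) {
    return;
  }

  // ... run the AOTI-compiled graph against `owned` ...

  // A tensor wrapped around an existing buffer via
  // aoti_torch_create_tensor_from_blob_v2 would be tracked as NOT_OWN and its
  // memory left untouched here; `owned` is the sole owner, so this call also
  // frees the CUDA managed allocation.
  aoti_torch_delete_tensor_object(owned);
}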