From dad276e3696a612c3dd77169d6d5f773f5c3eca4 Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Sun, 28 Sep 2025 13:46:55 -0700
Subject: [PATCH 1/5] tensor empty strided (#14549)

Summary:
This diff introduces aoti_torch_empty_strided to the ET CUDA backend. It is one of the main functions for creating an empty tensor with the given sizes and strides.

Reviewed By: larryliu0820

Differential Revision: D83094606
---
 backends/aoti/utils.h                         |   2 +
 backends/cuda/runtime/TARGETS                 |  32 +
 backends/cuda/runtime/shims/memory.cpp        | 135 ++++
 backends/cuda/runtime/shims/memory.h          |  55 ++
 backends/cuda/runtime/shims/tests/TARGETS     |   6 +
 backends/cuda/runtime/shims/tests/targets.bzl |  30 +
 .../tests/test_aoti_torch_empty_strided.cpp   | 588 ++++++++++++++++++
 backends/cuda/runtime/shims/utils.h           | 109 ++++
 8 files changed, 957 insertions(+)
 create mode 100644 backends/cuda/runtime/TARGETS
 create mode 100644 backends/cuda/runtime/shims/memory.cpp
 create mode 100644 backends/cuda/runtime/shims/memory.h
 create mode 100644 backends/cuda/runtime/shims/tests/TARGETS
 create mode 100644 backends/cuda/runtime/shims/tests/targets.bzl
 create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
 create mode 100644 backends/cuda/runtime/shims/utils.h

diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h
index 82d30cdb4ef..22734935df2 100644
--- a/backends/aoti/utils.h
+++ b/backends/aoti/utils.h
@@ -36,6 +36,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
   switch (dtype) {
     case 6: // PyTorch's float32 dtype code
       return executorch::aten::ScalarType::Float;
+    case 15: // PyTorch's bfloat16 dtype code
+      return executorch::aten::ScalarType::BFloat16;
     // Future support for additional dtypes can be added here
     default:
       ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype);
diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS
new file mode 100644
index 00000000000..1aa38760e5a
--- /dev/null
+++ b/backends/cuda/runtime/TARGETS
@@ -0,0 +1,32 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.cxx_library(
+    name = "runtime_shims",
+    srcs = [
+        "shims/memory.cpp",
+        "shims/tensor_attribute.cpp",
+    ],
+    headers = [
+        "shims/memory.h",
+        "shims/tensor_attribute.h",
+        "shims/utils.h",
+    ],
+    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+    link_whole = True,
+    supports_python_dlopen = True,
+    # Constructor needed for backend registration.
+    compiler_flags = ["-Wno-global-constructors"],
+    visibility = ["@EXECUTORCH_CLIENTS"],
+    deps = [
+        "//executorch/backends/aoti:common_shims",
+        "//executorch/extension/tensor:tensor",
+        "//executorch/runtime/core:core",
+        "//executorch/runtime/core/exec_aten:lib",
+        "//executorch/runtime/platform:platform",
+    ],
+    external_deps = [
+        ("cuda", None, "cuda-lazy"),
+    ],
+)
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
new file mode 100644
index 00000000000..99d936e32ca
--- /dev/null
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include // For posix_memalign +#include +#include +#include + +// CUDA error checking macro +#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \ + do { \ + const cudaError_t err = EXPR; \ + if (err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(err)); \ + return Error::Internal; \ + } while (0) + +// Kernel launch check macro +#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \ + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError()) + +namespace executorch { +namespace backends { +namespace cuda { + +using executorch::aten::SizesType; +using executorch::aten::StridesType; +using executorch::backends::aoti::dtype_to_element_size; +using executorch::backends::aoti::dtype_to_scalar_type; + +// Global storage for tensors and their metadata +std::unordered_set> tensors; + +extern "C" { + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor) { + // Check that device_index is always 0 + if (device_index != 0) { + ET_LOG(Error, "device_index must be 0, got: %d", device_index); + return Error::InvalidArgument; + } + + // This requires us to reserve CUDA memory and put it into a ETensor + void* ptr; + int64_t numel = 1; + for (int64_t i = 0; i < ndim; i++) { + numel *= sizes_ptr[i]; + } + + AOTITorchError dtype_error = validate_dtype(dtype); + if (dtype_error != Error::Ok) { + return dtype_error; + } + + size_t element_size = dtype_to_element_size(dtype); + if (element_size == 0) { + ET_LOG(Error, "Invalid element size for dtype: %d", dtype); + return Error::InvalidArgument; + } + int64_t nbytes = numel * element_size; + + if (device_type == 1) { // cuda + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes)); + } else if (device_type == 0) { // cpu + // Ensure 16-byte alignment for CPU memory to match CUDA requirements + int result = posix_memalign(&ptr, 16, nbytes); + if (result != 0) { + ET_LOG(Error, "Failed to allocate aligned CPU memory"); + return Error::MemoryAllocationFailed; + } + if (ptr == nullptr) { + ET_LOG(Error, "Failed to call posix_memalign"); + return Error::MemoryAllocationFailed; + } + } else { + ET_LOG( + Error, + "Need to implement empty_strided for non-CUDA non-CPU device type %d", + device_type); + return Error::NotImplemented; + } + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // ETensor creation with dynamic shape support for edge cases + auto tensor = executorch::extension::from_blob( + ptr, sizes, strides, dtype_to_scalar_type(dtype)); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + + return Error::Ok; +} + +// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors +void clear_all_tensors() { + tensors.clear(); +} + +} // extern "C" + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h new file mode 100644 index 00000000000..2fdfdd8a72c --- /dev/null +++ b/backends/cuda/runtime/shims/memory.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +extern "C" { + +/** + * Creates an uninitialized tensor with specified dimensions, strides, and + * dtyper on either CPU or CUDA device. + * + * @param ndim Number of dimensions in the tensor + * @param sizes_ptr Pointer to array of dimension sizes + * @param strides_ptr Pointer to array of strides for each dimension + * @param dtype Data type identifier (matches PyTorch scalar types) + * @param device_type Device type (0=CPU, 1=CUDA) + * @param device_index Device index (must be 0 for current implementation) + * @param ret_new_tensor Output parameter for the created tensor + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor); + +// Function to clear all tensors from internal storage +// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors +void clear_all_tensors(); + +} // extern "C" + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/tests/TARGETS b/backends/cuda/runtime/shims/tests/TARGETS new file mode 100644 index 00000000000..9ff3e83a8bd --- /dev/null +++ b/backends/cuda/runtime/shims/tests/TARGETS @@ -0,0 +1,6 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl new file mode 100644 index 00000000000..5737bdb00ab --- /dev/null +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -0,0 +1,30 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") + +def cuda_shim_cpp_unittest(name): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/aoti:common_shims", + "//executorch/backends/cuda/runtime:runtime_shims", + "//executorch/extension/tensor:tensor", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + "//executorch/runtime/core/exec_aten:lib", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + cuda_shim_cpp_unittest("aoti_torch_empty_strided") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp new file mode 100644 index 00000000000..8e6998f457c --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp @@ -0,0 +1,588 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_empty_strided tests +class AOTITorchEmptyStridedTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors + Tensor* create_tracked_tensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA), + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? tensor : nullptr; + } +}; + +// Test aoti_torch_empty_strided basic functionality +TEST_F(AOTITorchEmptyStridedTest, BasicFunctionality) { + // Test 1D tensor + std::vector sizes_1d = {5}; + Tensor* tensor_1d; + AOTITorchError error = aoti_torch_empty_strided( + sizes_1d.size(), + sizes_1d.data(), + nullptr, // Let function compute strides + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_1d); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_1d, nullptr); + + // CRITICAL: Verify the tensor is actually float32 + int32_t actual_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_1d, &actual_dtype), Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::FLOAT32)) + << "Expected float32 dtype (" + << static_cast(SupportedDTypes::FLOAT32) << "), got " + << actual_dtype; + + // Verify element size (float32 should be 4 bytes per element) + size_t element_size = tensor_1d->element_size(); + EXPECT_EQ(element_size, 4) + << "Expected float32 element size to be 4 bytes, got " << element_size; + + // Verify total number of elements and memory usage + int64_t expected_numel = 5; // 5 elements + EXPECT_EQ(tensor_1d->numel(), expected_numel) + << "Expected " << expected_numel << " elements, got " + << tensor_1d->numel(); + + // Verify total memory size (numel * element_size) + size_t expected_memory_size = expected_numel * 4; // 5 * 4 = 20 bytes + size_t actual_memory_size = tensor_1d->numel() * tensor_1d->element_size(); + EXPECT_EQ(actual_memory_size, expected_memory_size) + << "Expected " << expected_memory_size << " bytes, got " + << actual_memory_size; + + // Check tensor properties + EXPECT_EQ(tensor_1d->dim(), 1); + EXPECT_EQ(tensor_1d->size(0), 5); + + // Test 2D tensor with explicit 
strides + std::vector sizes_2d = {3, 4}; + std::vector strides_2d = {4, 1}; + Tensor* tensor_2d; + error = aoti_torch_empty_strided( + sizes_2d.size(), + sizes_2d.data(), + strides_2d.data(), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_2d); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_2d, nullptr); + + // Verify 2D tensor is also float32 + int32_t dtype_2d; + EXPECT_EQ(aoti_torch_get_dtype(tensor_2d, &dtype_2d), Error::Ok); + EXPECT_EQ(dtype_2d, static_cast(SupportedDTypes::FLOAT32)) + << "Expected float32 dtype (" + << static_cast(SupportedDTypes::FLOAT32) << "), got " + << dtype_2d; + + // Verify element size for 2D tensor + EXPECT_EQ(tensor_2d->element_size(), 4); + + // Check tensor properties + EXPECT_EQ(tensor_2d->dim(), 2); + EXPECT_EQ(tensor_2d->size(0), 3); + EXPECT_EQ(tensor_2d->size(1), 4); + + // Verify memory size for 2D tensor + int64_t expected_numel_2d = 3 * 4; // 12 elements + size_t expected_memory_2d = expected_numel_2d * 4; // 12 * 4 = 48 bytes + EXPECT_EQ(tensor_2d->numel() * tensor_2d->element_size(), expected_memory_2d); +} + +// Test aoti_torch_empty_strided with CPU device +TEST_F(AOTITorchEmptyStridedTest, CPUDevice) { + std::vector sizes = {2, 3}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // Let function compute strides + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU), + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); +} + +// Test aoti_torch_empty_strided with invalid dtype +TEST_F(AOTITorchEmptyStridedTest, InvalidDtype) { + std::vector sizes = {2, 3}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 999, // invalid dtype + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test aoti_torch_empty_strided with unsupported device +TEST_F(AOTITorchEmptyStridedTest, UnsupportedDevice) { + std::vector sizes = {2, 3}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 2, // unsupported device type + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::NotImplemented); +} + +// Test aoti_torch_empty_strided with zero-sized tensor +TEST_F(AOTITorchEmptyStridedTest, ZeroSized) { + std::vector sizes = {0, 5}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 0); + EXPECT_EQ(tensor->size(1), 5); +} + +// Test aoti_torch_empty_strided scalar tensor (0D) +TEST_F(AOTITorchEmptyStridedTest, Scalar) { + std::vector sizes = {}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 0); +} + +// Test aoti_torch_empty_strided with large tensor +TEST_F(AOTITorchEmptyStridedTest, LargeTensor) { + std::vector sizes = 
{100, 200, 50}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 100); + EXPECT_EQ(tensor->size(1), 200); + EXPECT_EQ(tensor->size(2), 50); +} + +// Test error handling with memory allocation failures +TEST_F(AOTITorchEmptyStridedTest, MemoryAllocationStress) { + // Try to create a very large tensor that might cause allocation failure + // (This test may pass or fail depending on available memory) + std::vector huge_sizes = {10000, 10000, 100}; // ~38GB for float32 + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + huge_sizes.size(), + huge_sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor); + + // Either succeed or fail with memory allocation error + if (error == Error::Ok) { + EXPECT_NE(tensor, nullptr); + } else { + EXPECT_EQ(error, Error::MemoryAllocationFailed); + } +} + +// Test aoti_torch_empty_strided with bfloat16 dtype +TEST_F(AOTITorchEmptyStridedTest, BFloat16Tensor) { + // Test creating bfloat16 tensor on CUDA + std::vector sizes = {2, 3, 4}; + Tensor* tensor_bf16; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // Let function compute strides + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_bf16); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_bf16, nullptr); + + // CRITICAL: Verify the tensor is actually bfloat16 + int32_t actual_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16, &actual_dtype), Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << actual_dtype; + + // Verify element size (bfloat16 should be 2 bytes per element) + size_t element_size = tensor_bf16->element_size(); + EXPECT_EQ(element_size, 2) + << "Expected bfloat16 element size to be 2 bytes, got " << element_size; + + // Verify total number of elements and memory usage + int64_t expected_numel = 2 * 3 * 4; // 24 elements + EXPECT_EQ(tensor_bf16->numel(), expected_numel) + << "Expected " << expected_numel << " elements, got " + << tensor_bf16->numel(); + + // Verify total memory size (numel * element_size) + size_t expected_memory_size = expected_numel * 2; // 24 * 2 = 48 bytes + size_t actual_memory_size = + tensor_bf16->numel() * tensor_bf16->element_size(); + EXPECT_EQ(actual_memory_size, expected_memory_size) + << "Expected " << expected_memory_size << " bytes, got " + << actual_memory_size; + + // Check tensor properties + EXPECT_EQ(tensor_bf16->dim(), 3); + EXPECT_EQ(tensor_bf16->size(0), 2); + EXPECT_EQ(tensor_bf16->size(1), 3); + EXPECT_EQ(tensor_bf16->size(2), 4); + + // Verify we can get tensor metadata + int64_t* sizes_ptr; + int64_t* strides_ptr; + EXPECT_EQ(aoti_torch_get_sizes(tensor_bf16, &sizes_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor_bf16, &strides_ptr), Error::Ok); + + // Check sizes match + EXPECT_EQ(sizes_ptr[0], 2); + EXPECT_EQ(sizes_ptr[1], 3); + EXPECT_EQ(sizes_ptr[2], 4); + + // Check that strides are computed correctly (row-major order) + EXPECT_EQ(strides_ptr[0], 12); // 3 * 4 + EXPECT_EQ(strides_ptr[1], 4); // 4 + EXPECT_EQ(strides_ptr[2], 1); // 1 + + // Test bfloat16 tensor with custom 
strides + std::vector sizes_2d = {3, 2}; + std::vector strides_2d = {2, 1}; // Row-major strides + Tensor* tensor_bf16_custom; + error = aoti_torch_empty_strided( + sizes_2d.size(), + sizes_2d.data(), + strides_2d.data(), + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_bf16_custom); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_bf16_custom, nullptr); + + // Verify custom stride tensor is also bfloat16 + int32_t custom_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_custom, &custom_dtype), Error::Ok); + EXPECT_EQ(custom_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << custom_dtype; + + // Verify element size for custom stride tensor + EXPECT_EQ(tensor_bf16_custom->element_size(), 2); + + // Check tensor properties + EXPECT_EQ(tensor_bf16_custom->dim(), 2); + EXPECT_EQ(tensor_bf16_custom->size(0), 3); + EXPECT_EQ(tensor_bf16_custom->size(1), 2); + + // Verify memory size for custom stride tensor + int64_t custom_expected_numel = 3 * 2; // 6 elements + size_t custom_expected_memory = custom_expected_numel * 2; // 6 * 2 = 12 bytes + EXPECT_EQ( + tensor_bf16_custom->numel() * tensor_bf16_custom->element_size(), + custom_expected_memory); + + // Check custom strides + int64_t* custom_strides_ptr; + EXPECT_EQ( + aoti_torch_get_strides(tensor_bf16_custom, &custom_strides_ptr), + Error::Ok); + EXPECT_EQ(custom_strides_ptr[0], 2); + EXPECT_EQ(custom_strides_ptr[1], 1); + + // Test bfloat16 scalar tensor (0D) + std::vector scalar_sizes = {}; + Tensor* tensor_bf16_scalar; + error = aoti_torch_empty_strided( + scalar_sizes.size(), + scalar_sizes.data(), + nullptr, + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_bf16_scalar); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_bf16_scalar, nullptr); + EXPECT_EQ(tensor_bf16_scalar->dim(), 0); + + // Verify scalar tensor is also bfloat16 + int32_t scalar_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_scalar, &scalar_dtype), Error::Ok); + EXPECT_EQ(scalar_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << scalar_dtype; + + // Verify scalar tensor properties + EXPECT_EQ(tensor_bf16_scalar->element_size(), 2); + EXPECT_EQ(tensor_bf16_scalar->numel(), 1); // Scalar tensor has 1 element + EXPECT_EQ( + tensor_bf16_scalar->numel() * tensor_bf16_scalar->element_size(), + 2); // 1 * 2 = 2 bytes +} + +// Test custom strides functionality +TEST_F(AOTITorchEmptyStridedTest, CustomStrides) { + // Create tensor with valid custom strides (contiguous layout) + std::vector sizes = {2, 3}; + std::vector strides = {3, 1}; // Standard row-major strides + + Tensor* tensor = create_tracked_tensor(sizes, strides); + EXPECT_NE(tensor, nullptr); + + // Verify the tensor was created correctly + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + + // Check strides through AOTI interface + int64_t* strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok); + EXPECT_EQ(strides_ptr[0], 3); + EXPECT_EQ(strides_ptr[1], 1); + + // Test another valid stride pattern - transpose-like + std::vector sizes_2 = {3, 2}; + std::vector strides_2 = {1, 3}; // Column-major strides + + Tensor* tensor_2 = create_tracked_tensor(sizes_2, strides_2); + EXPECT_NE(tensor_2, nullptr); + + // Verify the tensor 
properties + EXPECT_EQ(tensor_2->dim(), 2); + EXPECT_EQ(tensor_2->size(0), 3); + EXPECT_EQ(tensor_2->size(1), 2); + + // Check strides + int64_t* strides_ptr_2; + EXPECT_EQ(aoti_torch_get_strides(tensor_2, &strides_ptr_2), Error::Ok); + EXPECT_EQ(strides_ptr_2[0], 1); + EXPECT_EQ(strides_ptr_2[1], 3); +} + +// Test edge case: zero-element tensor with non-zero dimensions +TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) { + std::vector sizes = {2, 0, 3}; // Total elements = 0 + Tensor* tensor = create_tracked_tensor(sizes); + EXPECT_NE(tensor, nullptr); + + // Verify the tensor properties + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 0); + EXPECT_EQ(tensor->size(2), 3); + + // Should be able to get metadata + int64_t* sizes_ptr; + int64_t* strides_ptr; + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok); + + EXPECT_EQ(sizes_ptr[0], 2); + EXPECT_EQ(sizes_ptr[1], 0); + EXPECT_EQ(sizes_ptr[2], 3); +} + +// Test different data types (only float32 is currently supported) +TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) { + std::vector sizes = {2, 3}; + + // Test float32 (dtype 6) - currently the only supported type + Tensor* tensor_float32; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor_float32); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_float32, nullptr); + + // Test unsupported data types should return error + Tensor* tensor_int32; + error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 3, // int32 - unsupported + 1, // CUDA device + 0, // device index + &tensor_int32); + + EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype + + // Test another unsupported data type + Tensor* tensor_float64; + error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 7, // float64 - unsupported + 1, // CUDA device + 0, // device index + &tensor_float64); + + EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype +} + +// Test multi-dimensional tensors with various shapes +TEST_F(AOTITorchEmptyStridedTest, MultiDimensionalTensors) { + // Test 3D tensor + std::vector sizes_3d = {2, 3, 4}; + Tensor* tensor_3d = create_tracked_tensor(sizes_3d); + EXPECT_NE(tensor_3d, nullptr); + EXPECT_EQ(tensor_3d->dim(), 3); + EXPECT_EQ(tensor_3d->size(0), 2); + EXPECT_EQ(tensor_3d->size(1), 3); + EXPECT_EQ(tensor_3d->size(2), 4); + + // Test 4D tensor + std::vector sizes_4d = {2, 3, 4, 5}; + Tensor* tensor_4d = create_tracked_tensor(sizes_4d); + EXPECT_NE(tensor_4d, nullptr); + EXPECT_EQ(tensor_4d->dim(), 4); + EXPECT_EQ(tensor_4d->size(0), 2); + EXPECT_EQ(tensor_4d->size(1), 3); + EXPECT_EQ(tensor_4d->size(2), 4); + EXPECT_EQ(tensor_4d->size(3), 5); + + // Test 5D tensor + std::vector sizes_5d = {1, 2, 3, 4, 5}; + Tensor* tensor_5d = create_tracked_tensor(sizes_5d); + EXPECT_NE(tensor_5d, nullptr); + EXPECT_EQ(tensor_5d->dim(), 5); + EXPECT_EQ(tensor_5d->size(0), 1); + EXPECT_EQ(tensor_5d->size(1), 2); + EXPECT_EQ(tensor_5d->size(2), 3); + EXPECT_EQ(tensor_5d->size(3), 4); + EXPECT_EQ(tensor_5d->size(4), 5); +} diff --git a/backends/cuda/runtime/shims/utils.h b/backends/cuda/runtime/shims/utils.h new file mode 100644 index 00000000000..23943391b50 --- /dev/null +++ b/backends/cuda/runtime/shims/utils.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. 
and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+// Enum for supported data types in et-cuda backend
+enum class SupportedDTypes : int32_t {
+  FLOAT32 = 6, // PyTorch's float32 dtype code
+  BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
+};
+
+// Enum for supported device types in et-cuda backend
+enum class SupportedDevices : int32_t {
+  CPU = 0, // CPU device
+  CUDA = 1, // CUDA device
+};
+
+// Utility function to convert sizes pointer to vector
+inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr) {
+  std::vector<executorch::aten::SizesType> sizes(ndim);
+  for (int i = 0; i < ndim; i++) {
+    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
+  }
+  return sizes;
+}
+
+// Utility function to convert strides pointer to vector, or calculate
+// contiguous strides from sizes when no strides are provided
+inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  std::vector<executorch::aten::StridesType> strides(ndim);
+
+  if (strides_ptr != nullptr) {
+    // Use the provided strides. It is OK if the provided strides are not
+    // contiguous strides, since they are only used internally in the CUDA
+    // delegate.
+    for (int64_t i = 0; i < ndim; i++) {
+      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
+    }
+  } else {
+    // Calculate strides from sizes using ExecuTorch's algorithm
+    if (ndim > 0) {
+      strides[ndim - 1] =
+          static_cast<executorch::aten::StridesType>(1); // Last dimension has stride 1
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        if (sizes_ptr[i + 1] == 0) {
+          strides[i] = strides[i + 1]; // Copy stride when size is 0
+        } else {
+          strides[i] = static_cast<executorch::aten::StridesType>(
+              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
+        }
+      }
+    }
+  }
+  return strides;
+}
+
+extern "C" {
+using executorch::runtime::Error;
+// Common AOTI type aliases
+using AOTITorchError = Error;
+
+// Helper function to check if a dtype is supported in ET CUDA backend
+inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
+  switch (dtype) {
+    case static_cast<int32_t>(SupportedDTypes::FLOAT32):
+    case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Dtype validation utility function
+inline AOTITorchError validate_dtype(int32_t dtype) {
+  if (is_dtype_supported_in_et_cuda(dtype)) {
+    return Error::Ok;
+  }
+
+  ET_LOG(
+      Error,
+      "Unsupported dtype: %d. Supported dtypes: %d (float32), %d (bfloat16)",
+      dtype,
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+  return Error::InvalidArgument;
+}
+} // extern "C"
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch

From 7a56420341e13b867a25654425776238b02f7f25 Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Sun, 28 Sep 2025 13:46:55 -0700
Subject: [PATCH 2/5] tensor destroy (#14575)

Summary:
This diff introduces `aoti_torch_delete_tensor_object` for deleting tensors created during CUDA backend inference.
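
A minimal usage sketch of the two shims together (illustrative only; the variable names are made up for the example, and dtype code 6 = float32, device_type 1 = CUDA are the codes used elsewhere in this stack):

    // Create an uninitialized 2x3 float32 tensor on CUDA device 0, then delete it.
    Tensor* t = nullptr;
    const int64_t sizes[] = {2, 3};
    AOTITorchError err = aoti_torch_empty_strided(
        /*ndim=*/2, sizes, /*strides_ptr=*/nullptr,
        /*dtype=*/6, /*device_type=*/1, /*device_index=*/0, &t);
    if (err == Error::Ok) {
      // For tensors allocated by aoti_torch_empty_strided, deletion also
      // frees the backing memory.
      err = aoti_torch_delete_tensor_object(t);
    }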
Reviewed By: larryliu0820 Differential Revision: D83094605 --- backends/cuda/runtime/shims/memory.cpp | 69 ++- backends/cuda/runtime/shims/memory.h | 10 +- backends/cuda/runtime/shims/tests/targets.bzl | 1 + .../test_aoti_torch_delete_tensor_object.cpp | 454 ++++++++++++++++++ 4 files changed, 532 insertions(+), 2 deletions(-) create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 99d936e32ca..12a1d59e5e1 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -123,11 +123,78 @@ AOTITorchError aoti_torch_empty_strided( return Error::Ok; } -// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors void clear_all_tensors() { + // Use aoti_torch_delete_tensor_object to properly delete each tensor + // Note: We need to collect tensor pointers first since deletion modifies the + // set + auto old_tensors = + std::move(tensors); // tensors is now empty and no need to copy + for (const auto& tensor_shared : old_tensors) { + aoti_torch_delete_tensor_object(tensor_shared.get()); + } + + // tensors set should now be empty, but ensure it's cleared tensors.clear(); } +AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) { + // Handle null tensor pointer + if (tensor == nullptr) { + ET_LOG(Error, "Cannot delete null tensor"); + return Error::InvalidArgument; + } + + // Check if tensor exists in our tracking + bool found_in_tensors = false; + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + found_in_tensors = true; + break; + } + } + + // If tensor not found in our tracking, it's invalid + if (!found_in_tensors) { + ET_LOG(Error, "Didn't find tensor %p", tensor); + return Error::InvalidArgument; + } + + // Find and delete the tensor + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + // Get the tensor before erasing + auto tensor_ptr = *it; + + void* data_ptr = tensor_ptr->mutable_data_ptr(); + + // Determine if it's GPU memory + cudaPointerAttributes attributes{}; + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&attributes, data_ptr)); + + // et tensor does not own data; need to free them manually. + if (attributes.type == cudaMemoryTypeManaged) { + // This is CUDA managed memory - free with proper synchronization + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaDeviceSynchronize()); // Wait for all operations to complete + // BEFORE freeing + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr)); + } else { + // This is CPU memory - free immediately + free(data_ptr); + } + // Remove from set (this will call the destructor if it's the last + // reference) + tensors.erase(it); + return Error::Ok; + } + } + + // This should never be reached since we found it above + ET_LOG(Error, "Internal error: tensor not found after validation"); + return Error::Internal; +} + } // extern "C" } // namespace cuda diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h index 2fdfdd8a72c..93bd9c30e70 100644 --- a/backends/cuda/runtime/shims/memory.h +++ b/backends/cuda/runtime/shims/memory.h @@ -44,8 +44,16 @@ AOTITorchError aoti_torch_empty_strided( int32_t device_index, Tensor** ret_new_tensor); +/** + * Deletes a tensor object and frees its associated memory. 
+ * + * @param tensor Pointer to the tensor object to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); + // Function to clear all tensors from internal storage -// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors void clear_all_tensors(); } // extern "C" diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index 5737bdb00ab..1db52ce1b97 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -28,3 +28,4 @@ def define_common_targets(): TARGETS and BUCK files that call this function. """ cuda_shim_cpp_unittest("aoti_torch_empty_strided") + cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp new file mode 100644 index 00000000000..eceb141e9ca --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp @@ -0,0 +1,454 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_delete_tensor_object tests +class AOTITorchDeleteTensorObjectTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors + Tensor* create_test_tensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = 6, // float32 + int32_t device_type = 1, // CUDA + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? 
tensor : nullptr; + } +}; + +// Test basic deletion of CUDA tensor +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteCudaTensorBasic) { + // Create a CUDA tensor + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes, {}, 6, 1, 0); // CUDA device + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties before deletion + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test basic deletion of CPU tensor +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteCpuTensorBasic) { + // Create a CPU tensor + std::vector sizes = {3, 4}; + Tensor* tensor = create_test_tensor(sizes, {}, 6, 0, 0); // CPU device + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties before deletion + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->size(1), 4); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of null tensor pointer +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteNullTensor) { + AOTITorchError error = aoti_torch_delete_tensor_object(nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test deletion of tensor not in tracking system +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteUntrackedTensor) { + // Create a tensor and then clear the tracking system + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Clear the tracking system (simulating an untracked tensor) + clear_all_tensors(); + + // Try to delete the tensor - should fail + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test deletion of multiple tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteMultipleTensors) { + // Create multiple tensors + std::vector tensors; + + for (int i = 1; i <= 5; i++) { + std::vector sizes = {i, i + 1}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + tensors.push_back(tensor); + } + + // Delete all tensors + for (Tensor* tensor : tensors) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + } +} + +// Test deletion of zero-sized tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteZeroSizedTensor) { + // Create a zero-sized tensor + std::vector sizes = {0, 5}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 0); + EXPECT_EQ(tensor->size(1), 5); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of scalar (0D) tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteScalarTensor) { + // Create a scalar tensor + std::vector sizes = {}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 0); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of large multi-dimensional tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteLargeTensor) { + // Create a large multi-dimensional tensor + std::vector sizes = {10, 20, 30}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // 
Verify tensor properties + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 10); + EXPECT_EQ(tensor->size(1), 20); + EXPECT_EQ(tensor->size(2), 30); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of tensors with custom strides +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteTensorWithCustomStrides) { + // Create tensor with custom strides + std::vector sizes = {3, 4}; + std::vector strides = {4, 1}; // Row-major strides + Tensor* tensor = create_test_tensor(sizes, strides); + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->size(1), 4); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion after accessing tensor data +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteAfterDataAccess) { + // Create a tensor + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Access tensor data (this should not prevent deletion) + void* data_ptr = tensor->mutable_data_ptr(); + EXPECT_NE(data_ptr, nullptr); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test double deletion (should fail on second attempt) +TEST_F(AOTITorchDeleteTensorObjectTest, DoubleDeletion) { + // Create a tensor + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // First deletion should succeed + AOTITorchError error1 = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error1, Error::Ok); + + // Second deletion should fail (tensor no longer tracked) + AOTITorchError error2 = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error2, Error::InvalidArgument); +} + +// Test deletion of tensors on both CUDA and CPU devices +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteMixedDeviceTensors) { + // Create CUDA tensor + std::vector sizes = {2, 3}; + Tensor* cuda_tensor = create_test_tensor(sizes, {}, 6, 1, 0); + ASSERT_NE(cuda_tensor, nullptr); + + // Create CPU tensor + Tensor* cpu_tensor = create_test_tensor(sizes, {}, 6, 0, 0); + ASSERT_NE(cpu_tensor, nullptr); + + // Delete both tensors + AOTITorchError cuda_error = aoti_torch_delete_tensor_object(cuda_tensor); + EXPECT_EQ(cuda_error, Error::Ok); + + AOTITorchError cpu_error = aoti_torch_delete_tensor_object(cpu_tensor); + EXPECT_EQ(cpu_error, Error::Ok); +} + +// Test memory consistency after deletion +TEST_F(AOTITorchDeleteTensorObjectTest, MemoryConsistencyAfterDeletion) { + // Create multiple tensors + std::vector tensors; + const int num_tensors = 10; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {i + 1, i + 2}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + tensors.push_back(tensor); + } + + // Delete every other tensor + for (int i = 0; i < num_tensors; i += 2) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensors[i]); + EXPECT_EQ(error, Error::Ok); + } + + // Delete remaining tensors + for (int i = 1; i < num_tensors; i += 2) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensors[i]); + EXPECT_EQ(error, Error::Ok); + } +} + +// Test stress deletion with many small tensors +TEST_F(AOTITorchDeleteTensorObjectTest, StressDeletionManySmallTensors) { + const int num_tensors = 100; + std::vector tensors; + + // Create many 
small tensors + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {1, 1}; // Minimal size + Tensor* tensor = create_test_tensor(sizes); + if (tensor != nullptr) { + tensors.push_back(tensor); + } + } + + // Delete all created tensors + for (Tensor* tensor : tensors) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + } +} + +// Test CUDA synchronization during deletion +TEST_F(AOTITorchDeleteTensorObjectTest, CudaSynchronizationDuringDeletion) { + // Create a larger CUDA tensor to ensure memory allocation + std::vector sizes = {100, 100}; + Tensor* tensor = create_test_tensor(sizes, {}, 6, 1, 0); // CUDA device + ASSERT_NE(tensor, nullptr); + + // Delete the tensor (should handle synchronization internally) + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Verify CUDA state is still good + cudaError_t cuda_error = cudaGetLastError(); + EXPECT_EQ(cuda_error, cudaSuccess); +} + +// Test specific deletion of bfloat16 tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteBFloat16Tensor) { + // Test 1D bfloat16 tensor deletion + std::vector sizes_1d = {10}; + Tensor* tensor_bf16_1d = create_test_tensor( + sizes_1d, + {}, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_1d, nullptr); + + // Verify it's bfloat16 before deletion + int32_t actual_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_1d, &actual_dtype), Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << actual_dtype; + + // Verify element size (bfloat16 should be 2 bytes per element) + EXPECT_EQ(tensor_bf16_1d->element_size(), 2); + + // Delete the bfloat16 tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor_bf16_1d); + EXPECT_EQ(error, Error::Ok); + + // Test 2D bfloat16 tensor deletion with custom strides + std::vector sizes_2d = {4, 6}; + std::vector strides_2d = {6, 1}; // Row-major strides + Tensor* tensor_bf16_2d = create_test_tensor( + sizes_2d, + strides_2d, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_2d, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor_bf16_2d->dim(), 2); + EXPECT_EQ(tensor_bf16_2d->size(0), 4); + EXPECT_EQ(tensor_bf16_2d->size(1), 6); + EXPECT_EQ(tensor_bf16_2d->element_size(), 2); + + // Verify it's bfloat16 + int32_t dtype_2d; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_2d, &dtype_2d), Error::Ok); + EXPECT_EQ(dtype_2d, static_cast(SupportedDTypes::BFLOAT16)); + + // Delete the 2D bfloat16 tensor + error = aoti_torch_delete_tensor_object(tensor_bf16_2d); + EXPECT_EQ(error, Error::Ok); + + // Test 3D bfloat16 tensor deletion + std::vector sizes_3d = {2, 3, 4}; + Tensor* tensor_bf16_3d = create_test_tensor( + sizes_3d, + {}, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_3d, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor_bf16_3d->dim(), 3); + EXPECT_EQ(tensor_bf16_3d->size(0), 2); + EXPECT_EQ(tensor_bf16_3d->size(1), 3); + EXPECT_EQ(tensor_bf16_3d->size(2), 4); + EXPECT_EQ(tensor_bf16_3d->element_size(), 2); + + // Verify memory size (2 * 3 * 4 * 2 bytes = 48 bytes) + size_t expected_memory = 2 * 3 * 4 * 2; + size_t actual_memory = + tensor_bf16_3d->numel() * tensor_bf16_3d->element_size(); + EXPECT_EQ(actual_memory, expected_memory); + + // Delete the 3D bfloat16 tensor + error = 
aoti_torch_delete_tensor_object(tensor_bf16_3d);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Test bfloat16 scalar tensor (0D) deletion
+  std::vector<int64_t> scalar_sizes = {};
+  Tensor* tensor_bf16_scalar = create_test_tensor(
+      scalar_sizes,
+      {},
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16),
+      1, // CUDA device
+      0);
+  ASSERT_NE(tensor_bf16_scalar, nullptr);
+
+  // Verify scalar tensor properties
+  EXPECT_EQ(tensor_bf16_scalar->dim(), 0);
+  EXPECT_EQ(tensor_bf16_scalar->numel(), 1);
+  EXPECT_EQ(tensor_bf16_scalar->element_size(), 2);
+
+  // Delete the scalar bfloat16 tensor
+  error = aoti_torch_delete_tensor_object(tensor_bf16_scalar);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Test zero-element bfloat16 tensor deletion
+  std::vector<int64_t> zero_sizes = {0, 5};
+  Tensor* tensor_bf16_zero = create_test_tensor(
+      zero_sizes,
+      {},
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16),
+      1, // CUDA device
+      0);
+  ASSERT_NE(tensor_bf16_zero, nullptr);
+
+  // Verify zero-element tensor properties
+  EXPECT_EQ(tensor_bf16_zero->dim(), 2);
+  EXPECT_EQ(tensor_bf16_zero->size(0), 0);
+  EXPECT_EQ(tensor_bf16_zero->size(1), 5);
+  EXPECT_EQ(tensor_bf16_zero->numel(), 0);
+  EXPECT_EQ(tensor_bf16_zero->element_size(), 2);
+
+  // Delete the zero-element bfloat16 tensor
+  error = aoti_torch_delete_tensor_object(tensor_bf16_zero);
+  EXPECT_EQ(error, Error::Ok);
+}
+
+// Test deletion of mixed dtype tensors (float32 and bfloat16)

From a7f238d39c15eb4b2b48802d300403d4560bd7fd Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Sun, 28 Sep 2025 13:46:55 -0700
Subject: [PATCH 3/5] aoti_torch_create_tensor_from_blob_v2 (#14604)

Summary:
This diff introduces aoti_torch_create_tensor_from_blob_v2, a function that creates a tensor from an existing data blob with custom sizes and strides. Note that, unlike aoti_torch_empty_strided, a tensor created by aoti_torch_create_tensor_from_blob_v2 does not take ownership of the memory blob, so deleting the tensor does not free the underlying memory.
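
A minimal sketch of the intended usage (illustrative only; the buffer name and shape are made up for the example, storage_offset must be 0, and layout 0 means strided as documented in memory.h):

    // Wrap an existing unified-memory buffer without copying it.
    void* buf = nullptr;
    cudaMallocManaged(&buf, 2 * 3 * sizeof(float));
    const int64_t sizes[] = {2, 3};
    const int64_t strides[] = {3, 1};
    Tensor* t = nullptr;
    AOTITorchError err = aoti_torch_create_tensor_from_blob_v2(
        buf, /*ndim=*/2, sizes, strides, /*storage_offset=*/0,
        /*dtype=*/6, /*device_type=*/1, /*device_index=*/0, &t,
        /*layout=*/0, /*opaque_metadata=*/nullptr, /*opaque_metadata_size=*/0);
    if (err == Error::Ok) {
      // Deleting the tensor does not free buf; the tensor never owned it.
      aoti_torch_delete_tensor_object(t);
    }
    cudaFree(buf); // the caller still owns and frees the blob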
Reviewed By: larryliu0820 Differential Revision: D83094602 --- backends/aoti/utils.h | 17 + backends/cuda/runtime/shims/memory.cpp | 187 ++++- backends/cuda/runtime/shims/memory.h | 39 +- backends/cuda/runtime/shims/tests/targets.bzl | 1 + ..._aoti_torch_create_tensor_from_blob_v2.cpp | 754 ++++++++++++++++++ backends/cuda/runtime/shims/utils.h | 20 + 6 files changed, 981 insertions(+), 37 deletions(-) create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h index 22734935df2..1c872e08648 100644 --- a/backends/aoti/utils.h +++ b/backends/aoti/utils.h @@ -73,6 +73,23 @@ inline AOTITorchError validate_storage_offset(int64_t storage_offset) { return Error::Ok; } +// Check if tensor is in contiguous memory format (NCHW for 4D tensors) +// Contiguous format means strides decrease from left to right: +// For NCHW: strides = [C*H*W, H*W, W, 1] +inline bool is_tensor_contiguous( + int64_t ndim, + const int64_t* sizes, + const int64_t* strides) { + int64_t expected_stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + if (strides[i] != expected_stride) { + return false; + } + expected_stride *= sizes[i]; + } + return true; +} + } // extern "C" } // namespace aoti diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 12a1d59e5e1..94f589aece6 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -15,29 +15,10 @@ #include #include // For posix_memalign #include +#include #include #include -// CUDA error checking macro -#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \ - do { \ - const cudaError_t err = EXPR; \ - if (err == cudaSuccess) { \ - break; \ - } \ - ET_LOG( \ - Error, \ - "%s:%d CUDA error: %s", \ - __FILE__, \ - __LINE__, \ - cudaGetErrorString(err)); \ - return Error::Internal; \ - } while (0) - -// Kernel launch check macro -#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \ - ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError()) - namespace executorch { namespace backends { namespace cuda { @@ -46,12 +27,122 @@ using executorch::aten::SizesType; using executorch::aten::StridesType; using executorch::backends::aoti::dtype_to_element_size; using executorch::backends::aoti::dtype_to_scalar_type; +using executorch::backends::aoti::validate_storage_offset; // Global storage for tensors and their metadata std::unordered_set> tensors; +// Reference counting for memory addresses +// Maps memory address to number of tensors using it +// Special value: NOT_OWN (-1) means tensor never owns the memory +constexpr int32_t NOT_OWN = -1; +std::unordered_map memory_to_n_tensor; + extern "C" { +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + // TODO(gasoonjia): verify given data is on the target device + (void)device_type; + (void)opaque_metadata; + (void)layout; + (void)opaque_metadata_size; + + // Validate input parameters first + if (data == nullptr) { + ET_LOG( + Error, + "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null"); + return Error::InvalidArgument; + } + + if (sizes_ptr == nullptr && ndim > 0) { + ET_LOG( + Error, + "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null"); + return 
Error::InvalidArgument; + } + + if (ret_new_tensor == nullptr) { + ET_LOG( + Error, + "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null"); + return Error::InvalidArgument; + } + + // Check that device_index is always 0 + if (device_index != 0) { + ET_LOG(Error, "device_index must be 0, got: %d", device_index); + return Error::InvalidArgument; + } + + // Validate dtype using SupportedDTypes from utils.h + AOTITorchError dtype_error = validate_dtype(dtype); + if (dtype_error != Error::Ok) { + return dtype_error; + } + + // Storage offset must be 0 since from_blob cannot handle different offsets + AOTITorchError storage_offset_error = validate_storage_offset(storage_offset); + if (storage_offset_error != Error::Ok) { + return storage_offset_error; + } + + // Convert sizes to the format expected by ExecutorTorch using SizesType + std::vector sizes = + convert_sizes_to_vector(ndim, sizes_ptr); + + // Convert strides using the common helper function with StridesType + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create ExecutorTorch tensor that wraps the existing memory + // Note: We're NOT copying the data, just wrapping it + auto tensor = executorch::extension::from_blob( + data, // existing memory (don't copy!) + sizes, // tensor dimensions + strides, // tensor strides (allows different strides) + dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType + ); + + if (!tensor) { + ET_LOG(Error, "Failed to create tensor from blob"); + return Error::InvalidArgument; + } + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_new_tensor = tensor.get(); + + // Check if this memory address is already being tracked + auto memory_it = memory_to_n_tensor.find(data); + if (memory_it != memory_to_n_tensor.end()) { + ET_LOG( + Error, + "Memory address %p is already being tracked by another tensor", + data); + return Error::InvalidArgument; + } + + // Mark this memory as NOT_OWN since tensor created from blob never owns + // memory + memory_to_n_tensor[data] = NOT_OWN; + + return Error::Ok; +} + AOTITorchError aoti_torch_empty_strided( int64_t ndim, const int64_t* sizes_ptr, @@ -120,6 +211,9 @@ AOTITorchError aoti_torch_empty_strided( tensors.insert(tensor); *ret_new_tensor = tensor.get(); + // This tensor owns the memory it allocated, set reference count to 1 + memory_to_n_tensor[ptr] = 1; + return Error::Ok; } @@ -164,26 +258,47 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) { if (it->get() == tensor) { // Get the tensor before erasing auto tensor_ptr = *it; - void* data_ptr = tensor_ptr->mutable_data_ptr(); - // Determine if it's GPU memory - cudaPointerAttributes attributes{}; - ET_CUDA_CHECK_OR_RETURN_ERROR( - cudaPointerGetAttributes(&attributes, data_ptr)); - - // et tensor does not own data; need to free them manually. 
- if (attributes.type == cudaMemoryTypeManaged) { - // This is CUDA managed memory - free with proper synchronization - ET_CUDA_CHECK_OR_RETURN_ERROR( - cudaDeviceSynchronize()); // Wait for all operations to complete - // BEFORE freeing - ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr)); + // Find the reference count for this memory address + auto memory_it = memory_to_n_tensor.find(data_ptr); + if (memory_it != memory_to_n_tensor.end()) { + int32_t ref_count = memory_it->second; + + if (ref_count == NOT_OWN) { + // Tensor never owned the memory, skip freeing + // Just remove tensor from tracking + tensors.erase(it); + return Error::Ok; + } else if (ref_count == 1) { + // Only current tensor using this memory, free it + // Determine if it's GPU memory + cudaPointerAttributes attributes{}; + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&attributes, data_ptr)); + + if (attributes.type == cudaMemoryTypeManaged) { + // This is CUDA managed memory - free with proper synchronization + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize()); + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr)); + } else { + // This is CPU memory - free immediately + free(data_ptr); + data_ptr = nullptr; + } + + // Remove from memory tracking + memory_to_n_tensor.erase(memory_it); + } else if (ref_count > 1) { + // Other tensors still using this memory, just decrement count + memory_to_n_tensor[data_ptr] = ref_count - 1; + } } else { - // This is CPU memory - free immediately - free(data_ptr); + ET_LOG(Error, "Internal error: memory not found during deletion"); + return Error::Internal; } - // Remove from set (this will call the destructor if it's the last + + // Remove tensor from set (this will call the destructor if it's the last // reference) tensors.erase(it); return Error::Ok; diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h index 93bd9c30e70..7f4c56a8000 100644 --- a/backends/cuda/runtime/shims/memory.h +++ b/backends/cuda/runtime/shims/memory.h @@ -21,6 +21,44 @@ using executorch::backends::aoti::Tensor; extern "C" { +/** + * Creates a tensor object from an existing memory blob without copying the + * data. The tensor will wrap the provided memory and will not take ownership of + * it. When the tensor is deleted, the original memory will remain valid and + * must be freed by the caller. 
+ * + * @param data Pointer to the memory blob to wrap (must not be null) + * @param ndim Number of dimensions in the tensor + * @param sizes_ptr Pointer to array of dimension sizes (using SizesType) + * @param strides_ptr Pointer to array of strides for each dimension (using + * StridesType, can be null for contiguous) + * @param storage_offset Storage offset (must be 0 for current implementation) + * @param dtype Data type identifier (supports FLOAT32 and BFLOAT16 from + * SupportedDTypes) + * @param device_type Device type (CPU=0, CUDA=1 from SupportedDevices) + * @param device_index Device index (must be 0 for current implementation) + * @param ret_new_tensor Output parameter for the created tensor (must not be + * null) + * @param layout Tensor layout identifier (0=strided) + * @param opaque_metadata Optional metadata pointer (can be null) + * @param opaque_metadata_size Size of opaque metadata in bytes + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + /** * Creates an uninitialized tensor with specified dimensions, strides, and * dtyper on either CPU or CUDA device. @@ -55,7 +93,6 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); // Function to clear all tensors from internal storage void clear_all_tensors(); - } // extern "C" } // namespace cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index 1db52ce1b97..dce7d0be39c 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -29,3 +29,4 @@ def define_common_targets(): """ cuda_shim_cpp_unittest("aoti_torch_empty_strided") cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object") + cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp new file mode 100644 index 00000000000..2cb12719782 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp @@ -0,0 +1,754 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_create_tensor_from_blob_v2 tests +class AOTITorchCreateTensorFromBlobV2Test : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + + // Clean up any allocated memory buffers + for (void* ptr : cuda_memory_buffers_) { + if (ptr) { + cudaError_t cuda_err = cudaFree(ptr); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Failed to free CUDA memory: " << cudaGetErrorString(cuda_err); + } + } + cuda_memory_buffers_.clear(); + + for (void* ptr : cpu_memory_buffers_) { + if (ptr) { + free(ptr); + } + } + cpu_memory_buffers_.clear(); + } + + // Helper to allocate CUDA memory and track it for cleanup + void* allocate_cuda_memory(size_t bytes) { + void* ptr; + cudaError_t err = cudaMallocManaged(&ptr, bytes); + if (err == cudaSuccess) { + cuda_memory_buffers_.push_back(ptr); + return ptr; + } + return nullptr; + } + + // Helper to allocate CPU memory and track it for cleanup + void* allocate_cpu_memory(size_t bytes) { + void* ptr; + int result = posix_memalign(&ptr, 16, bytes); // 16-byte aligned + if (result == 0 && ptr != nullptr) { + cpu_memory_buffers_.push_back(ptr); + return ptr; + } + return nullptr; + } + + // Helper to calculate number of elements from sizes + int64_t calculate_numel(const std::vector& sizes) { + int64_t numel = 1; + for (int64_t size : sizes) { + numel *= size; + } + return numel; + } + + // Helper to calculate contiguous strides from sizes + std::vector calculate_contiguous_strides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + + strides[sizes.size() - 1] = 1; + // Use int64_t and check for underflow to avoid unsigned integer wraparound + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; + } + + private: + std::vector cuda_memory_buffers_; + std::vector cpu_memory_buffers_; +}; + +// Test basic functionality with CUDA memory +TEST_F(AOTITorchCreateTensorFromBlobV2Test, BasicFunctionalityCUDA) { + // Test 1D tensor + std::vector sizes_1d = {5}; + std::vector strides_1d = calculate_contiguous_strides(sizes_1d); + + // Allocate CUDA memory + size_t bytes = calculate_numel(sizes_1d) * sizeof(float); + void* cuda_data = allocate_cuda_memory(bytes); + ASSERT_NE(cuda_data, nullptr); + + Tensor* tensor_1d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + cuda_data, + sizes_1d.size(), + sizes_1d.data(), + strides_1d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_1d, + 0, // layout 
(strided) + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_1d, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor_1d->dim(), 1); + EXPECT_EQ(tensor_1d->size(0), 5); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor_1d->mutable_data_ptr(); + EXPECT_EQ(tensor_data, cuda_data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor_1d); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) For CUDA memory, check that we can still access it (synchronously) + // after tensor deletion + float pattern_value = 42.0f; + cudaError_t cuda_err = cudaMemcpy( + cuda_data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, cuda_data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test basic functionality with CPU memory +TEST_F(AOTITorchCreateTensorFromBlobV2Test, BasicFunctionalityCPU) { + // Test 2D tensor + std::vector sizes_2d = {3, 4}; + std::vector strides_2d = calculate_contiguous_strides(sizes_2d); + + // Allocate CPU memory + size_t bytes = calculate_numel(sizes_2d) * sizeof(float); + void* cpu_data = allocate_cpu_memory(bytes); + ASSERT_NE(cpu_data, nullptr); + + Tensor* tensor_2d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + cpu_data, + sizes_2d.size(), + sizes_2d.data(), + strides_2d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU), + 0, // device index + &tensor_2d, + 0, // layout (strided) + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_2d, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor_2d->dim(), 2); + EXPECT_EQ(tensor_2d->size(0), 3); + EXPECT_EQ(tensor_2d->size(1), 4); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor_2d->mutable_data_ptr(); + EXPECT_EQ(tensor_data, cpu_data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor_2d); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) For CPU memory, directly write and read to verify accessibility + float* float_ptr = reinterpret_cast(cpu_data); + float pattern_value = 42.0f; + *float_ptr = pattern_value; + EXPECT_EQ(*float_ptr, pattern_value) + << "Original CPU memory should still be accessible after tensor deletion"; +} + +// Test with invalid dtype +TEST_F(AOTITorchCreateTensorFromBlobV2Test, InvalidDtype) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + 999, // invalid dtype + static_cast(SupportedDevices::CUDA), + 0, // device index + 
&tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with non-zero storage offset (should fail since from_blob cannot handle +// offsets) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, NonZeroStorageOffset) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 1, // non-zero storage_offset (should fail since from_blob cannot handle + // offsets) + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with custom strides (using stride parameter but still contiguous) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, CustomContiguousStrides) { + std::vector sizes = {2, 3}; + // Use the correct contiguous strides but pass them explicitly + std::vector contiguous_strides = {3, 1}; // Proper contiguous strides + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + contiguous_strides.data(), // Explicitly pass contiguous strides + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Verify strides were properly set (we can check via aoti_torch_get_strides) + int64_t* tensor_strides; + error = aoti_torch_get_strides(tensor, &tensor_strides); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(tensor_strides[0], 3); + EXPECT_EQ(tensor_strides[1], 1); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test with null data pointer +TEST_F(AOTITorchCreateTensorFromBlobV2Test, NullDataPointer) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + nullptr, // 
null data pointer + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test scalar tensor (0D) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, ScalarTensor) { + std::vector sizes = {}; // 0D tensor + std::vector strides = {}; // Empty strides for scalar + + size_t bytes = sizeof(float); // Single element + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor = nullptr; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 0); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test zero-sized tensor +TEST_F(AOTITorchCreateTensorFromBlobV2Test, ZeroSizedTensor) { + std::vector sizes = {0, 5}; // Zero elements + std::vector strides = calculate_contiguous_strides(sizes); + + // Even for zero-sized tensor, we need some memory allocated + size_t bytes = sizeof(float); // Minimum allocation + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 0); + EXPECT_EQ(tensor->size(1), 5); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + 
EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test multi-dimensional tensors +TEST_F(AOTITorchCreateTensorFromBlobV2Test, MultiDimensionalTensors) { + // Test 3D tensor + std::vector sizes_3d = {2, 3, 4}; + std::vector strides_3d = calculate_contiguous_strides(sizes_3d); + + size_t bytes_3d = calculate_numel(sizes_3d) * sizeof(float); + void* data_3d = allocate_cuda_memory(bytes_3d); + ASSERT_NE(data_3d, nullptr); + + Tensor* tensor_3d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data_3d, + sizes_3d.size(), + sizes_3d.data(), + strides_3d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_3d, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_3d, nullptr); + EXPECT_EQ(tensor_3d->dim(), 3); + EXPECT_EQ(tensor_3d->size(0), 2); + EXPECT_EQ(tensor_3d->size(1), 3); + EXPECT_EQ(tensor_3d->size(2), 4); + + // Test 4D tensor + std::vector sizes_4d = {2, 3, 4, 5}; + std::vector strides_4d = calculate_contiguous_strides(sizes_4d); + + size_t bytes_4d = calculate_numel(sizes_4d) * sizeof(float); + void* data_4d = allocate_cuda_memory(bytes_4d); + ASSERT_NE(data_4d, nullptr); + + Tensor* tensor_4d; + error = aoti_torch_create_tensor_from_blob_v2( + data_4d, + sizes_4d.size(), + sizes_4d.data(), + strides_4d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_4d, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_4d, nullptr); + EXPECT_EQ(tensor_4d->dim(), 4); + EXPECT_EQ(tensor_4d->size(0), 2); + EXPECT_EQ(tensor_4d->size(1), 3); + EXPECT_EQ(tensor_4d->size(2), 4); + EXPECT_EQ(tensor_4d->size(3), 5); +} + +// Test tensor data pointer consistency +TEST_F(AOTITorchCreateTensorFromBlobV2Test, DataPointerConsistency) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* original_data = allocate_cuda_memory(bytes); + ASSERT_NE(original_data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + original_data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check that the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, original_data); +} + +// Test creating multiple tensors from different blobs +TEST_F(AOTITorchCreateTensorFromBlobV2Test, MultipleTensorsFromBlobs) { + const int num_tensors = 5; + std::vector tensors; + std::vector data_ptrs; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {i + 1, i + 2}; + std::vector strides = 
calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + data_ptrs.push_back(data); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + tensors.push_back(tensor); + + // Verify dimensions + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), i + 1); + EXPECT_EQ(tensor->size(1), i + 2); + + // Verify the tensor uses the correct data pointer + EXPECT_EQ(tensor->mutable_data_ptr(), data); + } + + // Verify all tensors have different data pointers + for (int i = 0; i < num_tensors; i++) { + EXPECT_EQ(tensors[i]->mutable_data_ptr(), data_ptrs[i]); + for (int j = i + 1; j < num_tensors; j++) { + EXPECT_NE(tensors[i]->mutable_data_ptr(), tensors[j]->mutable_data_ptr()); + } + } +} + +// Test deletion of tensor created from blob (should not free the original +// memory) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, DeletionDoesNotFreeOriginalMemory) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // The original memory should still be valid (we'll free it in teardown) + // We can't easily test if the memory is still valid without risking crashes, + // but the test should pass without issues if memory management is correct +} + +// Test with opaque metadata +TEST_F(AOTITorchCreateTensorFromBlobV2Test, WithOpaqueMetadata) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + // Create some opaque metadata + std::vector metadata = {0x01, 0x02, 0x03, 0x04}; + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + metadata.data(), // opaque_metadata + metadata.size()); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); +} + +// Test stress test with many small tensors from blobs +TEST_F(AOTITorchCreateTensorFromBlobV2Test, StressTestManySmallTensors) { + const int num_tensors = 50; // Reduced for reasonable test time + std::vector 
tensors; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {1, 1}; // Minimal size + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + if (data == nullptr) { + // Skip if we run out of memory + continue; + } + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + if (error == Error::Ok && tensor != nullptr) { + tensors.push_back(tensor); + + // Verify the tensor uses the correct data pointer + EXPECT_EQ(tensor->mutable_data_ptr(), data); + } + } + + // Delete all created tensors + for (Tensor* tensor : tensors) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + } +} diff --git a/backends/cuda/runtime/shims/utils.h b/backends/cuda/runtime/shims/utils.h index 23943391b50..38e56ca45a1 100644 --- a/backends/cuda/runtime/shims/utils.h +++ b/backends/cuda/runtime/shims/utils.h @@ -14,6 +14,26 @@ #include #include +// CUDA error checking macro +#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \ + do { \ + const cudaError_t err = EXPR; \ + if (err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(err)); \ + return Error::Internal; \ + } while (0) + +// Kernel launch check macro +#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \ + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError()) + namespace executorch { namespace backends { namespace cuda { From c4720dfed3947493a0e4a51525318ce1fc89321d Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Sun, 28 Sep 2025 13:46:55 -0700 Subject: [PATCH 4/5] aoti_torch__reinterpret_tensor (#14614) Summary: Introduced aoti_torch__reinterpret_tensor, which creates a new tensor view that reinterprets the same underlying memory with custom shape and strides. 
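As a rough illustration (not part of the diff itself), reinterpreting an
existing 12-element tensor `src` as a 3x4 view could look like the sketch
below; `src`, `view`, `sizes`, and `strides` are hypothetical names:

    int64_t sizes[] = {3, 4};
    int64_t strides[] = {4, 1};
    Tensor* view = nullptr;
    AOTITorchError err = aoti_torch__reinterpret_tensor(
        src, /*ndim=*/2, sizes, strides, /*storage_offset=*/0, &view);
    // `view` shares `src`'s storage; for runtime-owned memory the shared
    // reference count is incremented, and the allocation is freed only when
    // the last tensor using that memory is deleted.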
Reviewed By: larryliu0820 Differential Revision: D83094603 --- backends/cuda/runtime/shims/memory.cpp | 117 +++ backends/cuda/runtime/shims/memory.h | 25 + backends/cuda/runtime/shims/tests/targets.bzl | 1 + .../test_aoti_torch__reinterpret_tensor.cpp | 810 ++++++++++++++++++ 4 files changed, 953 insertions(+) create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 94f589aece6..498a31d42aa 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -25,6 +25,8 @@ namespace cuda { using executorch::aten::SizesType; using executorch::aten::StridesType; +using executorch::backends::aoti::aoti_torch_get_device_index; +using executorch::backends::aoti::aoti_torch_get_dtype; using executorch::backends::aoti::dtype_to_element_size; using executorch::backends::aoti::dtype_to_scalar_type; using executorch::backends::aoti::validate_storage_offset; @@ -310,6 +312,121 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) { return Error::Internal; } +AOTITorchError aoti_torch__reinterpret_tensor( + Tensor* self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + Tensor** ret_new_tensor) { + // Validate input parameters first + if (self == nullptr) { + ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: self tensor is null"); + return Error::InvalidArgument; + } + + if (sizes_ptr == nullptr && ndim > 0) { + ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: sizes_ptr is null"); + return Error::InvalidArgument; + } + + if (ret_new_tensor == nullptr) { + ET_LOG( + Error, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null"); + return Error::InvalidArgument; + } + + // Check if storage_offset is not 0 - return error if not + AOTITorchError storage_offset_error = validate_storage_offset(storage_offset); + if (storage_offset_error != Error::Ok) { + return storage_offset_error; + } + + // Get the device info from the source tensor to perform device_index + // validation + int32_t device_type = 0; + int32_t device_index = 0; + AOTITorchError device_error = aoti_torch_get_device_type(self, &device_type); + if (device_error != Error::Ok) { + return device_error; + } + + device_error = aoti_torch_get_device_index(self, &device_index); + if (device_error != Error::Ok) { + return device_error; + } + + // Ensure device_index is always 0 + if (device_index != 0) { + ET_LOG(Error, "device_index must be 0, got: %d", device_index); + return Error::InvalidArgument; + } + + // Get the dtype from the source tensor + int32_t dtype = 0; + AOTITorchError dtype_error = aoti_torch_get_dtype(self, &dtype); + if (dtype_error != Error::Ok) { + return dtype_error; + } + + // Validate dtype using SupportedDTypes + dtype_error = validate_dtype(dtype); + if (dtype_error != Error::Ok) { + return dtype_error; + } + + // Get the original data pointer from the source tensor + void* data_ptr = self->mutable_data_ptr(); + if (data_ptr == nullptr) { + ET_LOG(Error, "Source tensor has null data pointer"); + return Error::InvalidArgument; + } + + // Check if the given memory is in the map, if not return error + auto memory_it = memory_to_n_tensor.find(data_ptr); + if (memory_it == memory_to_n_tensor.end()) { + ET_LOG( + Error, + "Memory address %p is not being tracked by reference counting system", + data_ptr); + return Error::InvalidArgument; + } + + // Convert sizes using utility function 
from utils.h + std::vector sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // Convert strides using utility function from utils.h + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create new tensor view that reinterprets the same memory with different + // shape/strides This creates a view, not a copy - the data pointer is shared + std::shared_ptr tensor = executorch::extension::from_blob( + data_ptr, // Reuse the same memory from source tensor + sizes, // New sizes with explicit SizesType + strides, // New strides with explicit StridesType + dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting + ); + + if (!tensor) { + ET_LOG(Error, "Failed to create reinterpreted tensor view"); + return Error::InvalidArgument; + } + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_new_tensor = tensor.get(); + + // Increment the reference count for this memory address only if it is owned + // by tensor + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + + return Error::Ok; +} + } // extern "C" } // namespace cuda diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h index 7f4c56a8000..4e9780840e1 100644 --- a/backends/cuda/runtime/shims/memory.h +++ b/backends/cuda/runtime/shims/memory.h @@ -91,6 +91,31 @@ AOTITorchError aoti_torch_empty_strided( */ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); +/** + * Creates a tensor view that reinterprets the same underlying memory with + * different shape and strides without copying data. + * + * Note that the new tensor will not have the ownership of the underlying + * memory. + * + * @param self Input tensor whose memory will be reinterpreted + * @param ndim Number of dimensions for the new tensor view + * @param sizes_ptr Array of sizes for each dimension + * @param strides_ptr Array of strides for each dimension (or nullptr for + * contiguous) + * @param storage_offset Storage offset (must be 0) + * @param ret_new_tensor Output pointer to store the new tensor view + * + * @return Error::Ok on success, appropriate error code on failure + */ +AOTITorchError aoti_torch__reinterpret_tensor( + Tensor* self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + Tensor** ret_new_tensor); + // Function to clear all tensors from internal storage void clear_all_tensors(); } // extern "C" diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index dce7d0be39c..ac6d2072d58 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -30,3 +30,4 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_empty_strided") cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object") cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2") + cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp new file mode 100644 index 00000000000..ef00ecff656 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp @@ -0,0 +1,810 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch__reinterpret_tensor tests +class AOTITorchReinterpretTensorTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to calculate number of elements from sizes + int64_t calculate_numel(const std::vector& sizes) { + int64_t numel = 1; + for (int64_t size : sizes) { + numel *= size; + } + return numel; + } + + // Helper to calculate contiguous strides from sizes + std::vector calculate_contiguous_strides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + + strides[sizes.size() - 1] = 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; + } + + // Helper to create a source tensor using empty_strided (which allocates new + // memory) + Tensor* create_source_tensor( + const std::vector& sizes, + int32_t dtype = 6, // float32 + int32_t device_type = 1, // CUDA + int32_t device_index = 0) { + std::vector strides = calculate_contiguous_strides(sizes); + + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides.data(), + dtype, + device_type, + device_index, + &tensor); + + if (error != Error::Ok) { + return nullptr; + } + + return tensor; + } + + private: + std::vector cuda_memory_buffers_; + std::vector cpu_memory_buffers_; +}; + +// Test basic functionality: reinterpret tensor with different shapes +TEST_F(AOTITorchReinterpretTensorTest, BasicReinterpretation) { + // Create a source tensor with shape [12] (1D with 12 elements) + std::vector source_sizes = {12}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + // Store the original data pointer + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Reinterpret as [3, 4] (2D with same number of elements) + std::vector new_sizes = {3, 4}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor has the new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 3); 
+ EXPECT_EQ(reinterpreted_tensor->size(1), 4); + + // CRITICAL: Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted tensor should use the same memory as the source tensor"; + + // Write data through the original tensor and verify it's visible through the + // reinterpreted tensor + std::vector test_data = { + 1.0f, + 2.0f, + 3.0f, + 4.0f, + 5.0f, + 6.0f, + 7.0f, + 8.0f, + 9.0f, + 10.0f, + 11.0f, + 12.0f}; + cudaError_t cuda_err = cudaMemcpy( + original_data_ptr, + test_data.data(), + test_data.size() * sizeof(float), + cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Read back through the reinterpreted tensor + std::vector readback_data(12); + cuda_err = cudaMemcpy( + readback_data.data(), + reinterpreted_data_ptr, + readback_data.size() * sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Verify the data matches + for (size_t i = 0; i < test_data.size(); i++) { + EXPECT_EQ(readback_data[i], test_data[i]) + << "Data should be the same through both tensors at index " << i; + } +} + +// Test reinterpreting with different strides +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretWithCustomStrides) { + // Create a source tensor with shape [2, 6] (contiguous) + std::vector source_sizes = {2, 6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Reinterpret as [3, 4] with custom strides (still valid for the same memory) + std::vector new_sizes = {3, 4}; + std::vector new_strides = {4, 1}; // Row-major strides for [3, 4] + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 3); + EXPECT_EQ(reinterpreted_tensor->size(1), 4); + + // CRITICAL: Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted tensor should use the same memory as the source tensor"; + + // Verify strides were set correctly + int64_t* tensor_strides; + error = aoti_torch_get_strides(reinterpreted_tensor, &tensor_strides); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(tensor_strides[0], 4); + EXPECT_EQ(tensor_strides[1], 1); +} + +// Test error cases: null input tensor +TEST_F(AOTITorchReinterpretTensorTest, NullInputTensor) { + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + nullptr, // null input tensor + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error cases: null sizes pointer +TEST_F(AOTITorchReinterpretTensorTest, NullSizesPointer) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + std::vector new_strides = {2, 1}; + + 
Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + 2, // ndim > 0 + nullptr, // null sizes pointer + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error cases: null return tensor pointer +TEST_F(AOTITorchReinterpretTensorTest, NullReturnTensorPointer) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + nullptr); // null return tensor pointer + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error cases: non-zero storage offset (should fail) +TEST_F(AOTITorchReinterpretTensorTest, NonZeroStorageOffset) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 1, // non-zero storage_offset (should fail) + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test reinterpreting CPU tensor +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretCPUTensor) { + // Create a CPU tensor with shape [8] + std::vector source_sizes = {8}; + Tensor* source_tensor = create_source_tensor( + source_sizes, + 6, // float32 + 0, // CPU device + 0); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Reinterpret as [2, 4] + std::vector new_sizes = {2, 4}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted CPU tensor should use the same memory as the source tensor"; + + // Test direct memory access for CPU tensors + float* original_float_ptr = reinterpret_cast(original_data_ptr); + float* reinterpreted_float_ptr = + reinterpret_cast(reinterpreted_data_ptr); + + // Write through original and read through reinterpreted + original_float_ptr[0] = 42.0f; + EXPECT_EQ(reinterpreted_float_ptr[0], 42.0f) + << "Changes through original tensor should be visible through reinterpreted tensor"; +} + +// Test that deleting source tensor doesn't affect reinterpreted tensor (they +// share memory) +TEST_F(AOTITorchReinterpretTensorTest, DeletionBehavior) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* shared_data_ptr = source_tensor->mutable_data_ptr(); + + // Reinterpret as [2, 3] + std::vector new_sizes = {2, 3}; + std::vector new_strides = 
calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Verify they share the same memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), shared_data_ptr); + + // Delete the source tensor (which owns the memory) + error = aoti_torch_delete_tensor_object(source_tensor); + EXPECT_EQ(error, Error::Ok); + + // The reinterpreted tensor should still be valid but the memory might be + // freed Since the source tensor owned the memory, the reinterpreted tensor + // becomes invalid This is expected behavior - the user needs to manage the + // lifecycle properly + + // Clean up the reinterpreted tensor + error = aoti_torch_delete_tensor_object(reinterpreted_tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test scalar tensor reinterpretation +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretScalarTensor) { + // Create a scalar tensor (0D) + std::vector source_sizes = {}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + + // Try to reinterpret scalar as [1] (1D with 1 element) + std::vector new_sizes = {1}; + std::vector new_strides = {1}; + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), original_data_ptr); + + // Check new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 1); + EXPECT_EQ(reinterpreted_tensor->size(0), 1); +} + +// Test reinterpreting tensor with zero-sized dimension +// TODO: This test is disabled because zero-sized tensors have complex stride +// validation requirements that need further investigation +TEST_F(AOTITorchReinterpretTensorTest, DISABLED_ReinterpretZeroSizedTensor) { + // Create a tensor with shape [0, 5] (zero elements) + std::vector source_sizes = {0, 5}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + + // Reinterpret as [5, 0] (still zero elements) + std::vector new_sizes = {5, 0}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), original_data_ptr); + + // Check new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 5); + EXPECT_EQ(reinterpreted_tensor->size(1), 0); +} + +// Test with nullptr strides (should use contiguous strides) +TEST_F(AOTITorchReinterpretTensorTest, NullStridesPointer) { + std::vector source_sizes = {12}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = 
source_tensor->mutable_data_ptr(); + + // Reinterpret as [3, 4] with null strides (should calculate contiguous + // strides) + std::vector new_sizes = {3, 4}; + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + nullptr, // null strides - should calculate contiguous strides + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), original_data_ptr); + + // Check that contiguous strides were calculated correctly + int64_t* tensor_strides; + error = aoti_torch_get_strides(reinterpreted_tensor, &tensor_strides); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(tensor_strides[0], 4); // stride for dimension 0 should be 4 + EXPECT_EQ(tensor_strides[1], 1); // stride for dimension 1 should be 1 +} + +// Test bf16 tensor reinterpretation +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretBF16Tensor) { + // Create a bf16 source tensor with shape [6] + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor( + source_sizes, + static_cast( + SupportedDTypes::BFLOAT16), // bf16 dtype from SupportedDTypes + static_cast( + SupportedDevices::CUDA), // CUDA device from SupportedDevices + 0); // device_index must be 0 + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Verify the tensor is actually bf16 + int32_t actual_dtype = 0; + AOTITorchError dtype_check_error = + aoti_torch_get_dtype(source_tensor, &actual_dtype); + EXPECT_EQ(dtype_check_error, Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Source tensor should have bfloat16 dtype"; + + // Reinterpret as [2, 3] (same number of elements) + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor has the new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 2); + EXPECT_EQ(reinterpreted_tensor->size(1), 3); + + // Verify the dtype is preserved as bf16 + int32_t reinterpreted_dtype = 0; + dtype_check_error = + aoti_torch_get_dtype(reinterpreted_tensor, &reinterpreted_dtype); + EXPECT_EQ(dtype_check_error, Error::Ok); + EXPECT_EQ( + reinterpreted_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Reinterpreted tensor should preserve bfloat16 dtype"; + + // CRITICAL: Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted tensor should use the same memory as the source tensor"; + + // Test memory sharing by writing data through the original tensor + // and verifying it's visible through the reinterpreted tensor + // Note: bf16 has 2 bytes per element + std::vector test_data_bf16 = { + 0x3F80, 0x4000, 0x4040, 0x4080, 0x40A0, 0x40C0}; // bf16 values + cudaError_t cuda_err = cudaMemcpy( + original_data_ptr, + test_data_bf16.data(), + test_data_bf16.size() * sizeof(uint16_t), + 
cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Read back through the reinterpreted tensor + std::vector readback_data_bf16(6); + cuda_err = cudaMemcpy( + readback_data_bf16.data(), + reinterpreted_data_ptr, + readback_data_bf16.size() * sizeof(uint16_t), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Verify the data matches + for (size_t i = 0; i < test_data_bf16.size(); i++) { + EXPECT_EQ(readback_data_bf16[i], test_data_bf16[i]) + << "BF16 data should be the same through both tensors at index " << i; + } +} + +// Test reference counting behavior - memory not in map should fail +TEST_F(AOTITorchReinterpretTensorTest, MemoryNotInMapShouldFail) { + // Create a tensor directly without using our allocation functions + // This should NOT be in the reference counting map + void* external_memory; + ASSERT_EQ( + cudaMallocManaged(&external_memory, 12 * sizeof(float)), cudaSuccess); + + // Create a tensor by manually wrapping this memory without going through our + // APIs + std::vector sizes = {12}; + std::vector strides = calculate_contiguous_strides(sizes); + + // Create the tensor directly using ExecutorTorch extension + auto tensor_shared = executorch::extension::from_blob( + external_memory, + convert_sizes_to_vector(sizes.size(), sizes.data()), + convert_strides_to_vector(sizes.size(), sizes.data(), strides.data()), + executorch::runtime::etensor::ScalarType::Float); + + ASSERT_TRUE(tensor_shared); + Tensor* external_tensor = tensor_shared.get(); + + // Try to reinterpret this tensor - should fail because memory is not in map + std::vector new_sizes = {3, 4}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + external_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + // Should fail because memory is not being tracked by reference counting + // system + EXPECT_EQ(error, Error::InvalidArgument); + + // Clean up the external memory + ASSERT_EQ(cudaFree(external_memory), cudaSuccess); +} + +// Test reference counting behavior - creating view increments reference count +TEST_F(AOTITorchReinterpretTensorTest, ViewCreationIncrementsReferenceCount) { + // Create a source tensor that owns memory (reference count = 1) + std::vector source_sizes = {12}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* shared_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(shared_data_ptr, nullptr); + + // Create first view - should increment reference count to 2 + std::vector view1_sizes = {3, 4}; + std::vector view1_strides = + calculate_contiguous_strides(view1_sizes); + + Tensor* view1_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + view1_sizes.size(), + view1_sizes.data(), + view1_strides.data(), + 0, + &view1_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view1_tensor, nullptr); + EXPECT_EQ(view1_tensor->mutable_data_ptr(), shared_data_ptr); + + // Create second view - should increment reference count to 3 + std::vector view2_sizes = {2, 6}; + std::vector view2_strides = + calculate_contiguous_strides(view2_sizes); + + Tensor* view2_tensor; + error = aoti_torch__reinterpret_tensor( + source_tensor, + view2_sizes.size(), + view2_sizes.data(), + view2_strides.data(), + 0, + &view2_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view2_tensor, nullptr); + 
EXPECT_EQ(view2_tensor->mutable_data_ptr(), shared_data_ptr); + + // Now delete the source tensor - memory should NOT be freed (reference count + // = 2) + error = aoti_torch_delete_tensor_object(source_tensor); + EXPECT_EQ(error, Error::Ok); + + // Both views should still be valid - test by accessing memory + float test_value = 42.0f; + cudaError_t cuda_err = cudaMemcpy( + shared_data_ptr, &test_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + float readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view1_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete first view - memory should still NOT be freed (reference count = 1) + error = aoti_torch_delete_tensor_object(view1_tensor); + EXPECT_EQ(error, Error::Ok); + + // Second view should still be valid + readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view2_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete second view - NOW memory should be freed (reference count = 0) + error = aoti_torch_delete_tensor_object(view2_tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test reference counting behavior with NOT_OWN memory (from blob) - should +// SUCCEED and keep NOT_OWN +TEST_F(AOTITorchReinterpretTensorTest, ViewOfNotOwnMemoryKeepsNotOwnStatus) { + // Allocate external memory + void* external_memory; + cudaError_t cuda_err = + cudaMallocManaged(&external_memory, 12 * sizeof(float)); + ASSERT_EQ(cuda_err, cudaSuccess); + + // Create tensor from blob (which marks memory as NOT_OWN) + std::vector blob_sizes = {12}; + std::vector blob_strides = calculate_contiguous_strides(blob_sizes); + + Tensor* blob_tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + external_memory, + blob_sizes.size(), + blob_sizes.data(), + blob_strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device_index + &blob_tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(blob_tensor, nullptr); + + // Create view of NOT_OWN memory - should SUCCEED and keep NOT_OWN status + std::vector view_sizes = {3, 4}; + std::vector view_strides = calculate_contiguous_strides(view_sizes); + + Tensor* view_tensor; + error = aoti_torch__reinterpret_tensor( + blob_tensor, + view_sizes.size(), + view_sizes.data(), + view_strides.data(), + 0, + &view_tensor); + + // Should succeed - NOT_OWN memory can be reinterpreted but stays NOT_OWN + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->mutable_data_ptr(), external_memory); + + // Verify both tensors share the same memory + EXPECT_EQ(blob_tensor->mutable_data_ptr(), view_tensor->mutable_data_ptr()); + + // Test memory sharing by writing data through one tensor and reading through + // the other + float test_value = 42.0f; + cuda_err = cudaMemcpy( + external_memory, &test_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + float readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete the blob tensor - external memory should NOT be freed 
(NOT_OWN + // behavior) + error = aoti_torch_delete_tensor_object(blob_tensor); + EXPECT_EQ(error, Error::Ok); + + // View tensor should still be valid - test by accessing memory + readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete view tensor - external memory should still NOT be freed (NOT_OWN + // behavior) + error = aoti_torch_delete_tensor_object(view_tensor); + EXPECT_EQ(error, Error::Ok); + + // External memory should still be accessible (proves neither tensor freed it) + readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, external_memory, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Clean up external memory manually (as expected for NOT_OWN memory) + ASSERT_EQ(cudaFree(external_memory), cudaSuccess); +} From d961c1b1e01e0ee3a8821143a69d16dccc162661 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Sun, 28 Sep 2025 13:46:55 -0700 Subject: [PATCH 5/5] aoti_torch_copy_ (#14615) Summary: This diff introduce `aoti_torch_copy_`, the function for copying tensor inside cuda backend. Right now it only support copy between tensors with same dtype. Reviewed By: larryliu0820 Differential Revision: D83094604 --- backends/cuda/runtime/shims/memory.cpp | 271 +++++++++++- backends/cuda/runtime/shims/memory.h | 25 ++ backends/cuda/runtime/shims/tests/targets.bzl | 1 + .../shims/tests/test_aoti_torch_copy_.cpp | 398 ++++++++++++++++++ 4 files changed, 692 insertions(+), 3 deletions(-) create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 498a31d42aa..b70a63f579a 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -27,6 +27,8 @@ using executorch::aten::SizesType; using executorch::aten::StridesType; using executorch::backends::aoti::aoti_torch_get_device_index; using executorch::backends::aoti::aoti_torch_get_dtype; +using executorch::backends::aoti::aoti_torch_get_sizes; +using executorch::backends::aoti::aoti_torch_get_strides; using executorch::backends::aoti::dtype_to_element_size; using executorch::backends::aoti::dtype_to_scalar_type; using executorch::backends::aoti::validate_storage_offset; @@ -40,6 +42,67 @@ std::unordered_set> tensors; constexpr int32_t NOT_OWN = -1; std::unordered_map memory_to_n_tensor; +namespace { + +// Calculate linear offset from strides and indices +int64_t calculate_linear_offset( + const int64_t* indices, + const int64_t* strides, + int64_t ndim) { + int64_t offset = 0; + for (int64_t i = 0; i < ndim; ++i) { + offset += indices[i] * strides[i]; + } + return offset; +} + +// Convert linear index to multi-dimensional indices based on sizes +void linear_to_indices( + int64_t linear_idx, + const int64_t* sizes, + int64_t ndim, + int64_t* indices) { + for (int64_t i = ndim - 1; i >= 0; --i) { + indices[i] = linear_idx % sizes[i]; + linear_idx /= sizes[i]; + } +} + +// Generic pointwise copy function that handles arbitrary strides +template +AOTITorchError pointwise_copy_generic( + T* dst_data, + const T* src_data, + const int64_t* dst_sizes, + const int64_t* dst_strides, + const int64_t* src_sizes, + const int64_t* src_strides, + int64_t dst_ndim, + int64_t src_ndim, + int64_t total_elements) { + std::vector 
dst_indices(dst_ndim); + std::vector src_indices(src_ndim); + + for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + // Convert linear index to multi-dimensional indices for both tensors + linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data()); + linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data()); + + // Calculate offsets for both source and destination + int64_t src_offset = + calculate_linear_offset(src_indices.data(), src_strides, src_ndim); + int64_t dst_offset = + calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim); + + // Copy element + dst_data[dst_offset] = src_data[src_offset]; + } + + return Error::Ok; +} + +} // anonymous namespace + extern "C" { AOTITorchError aoti_torch_create_tensor_from_blob_v2( @@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided( } int64_t nbytes = numel * element_size; - if (device_type == 1) { // cuda - ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes)); - } else if (device_type == 0) { // cpu + if (device_type == static_cast(SupportedDevices::CUDA)) { + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaMallocManaged(&ptr, static_cast(nbytes))); + } else if (device_type == static_cast(SupportedDevices::CPU)) { // Ensure 16-byte alignment for CPU memory to match CUDA requirements int result = posix_memalign(&ptr, 16, nbytes); if (result != 0) { @@ -312,6 +376,207 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) { return Error::Internal; } +AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) { + (void)non_blocking; + + // Check for null pointers first + if (self == nullptr) { + ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null"); + return Error::InvalidArgument; + } + + if (src == nullptr) { + ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null"); + return Error::InvalidArgument; + } + + // Get dtype information and validate compatibility + int32_t self_dtype, src_dtype; + aoti_torch_get_dtype(self, &self_dtype); + aoti_torch_get_dtype(src, &src_dtype); + + AOTITorchError self_dtype_error = validate_dtype(self_dtype); + if (self_dtype_error != Error::Ok) { + return self_dtype_error; + } + + AOTITorchError src_dtype_error = validate_dtype(src_dtype); + if (src_dtype_error != Error::Ok) { + return src_dtype_error; + } + + // Check dtype compatibility - both tensors must have the same dtype + if (self_dtype != src_dtype) { + ET_LOG( + Error, + "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes", + self_dtype, + src_dtype); + return Error::InvalidArgument; + } + + // Check total number of elements compatibility (PyTorch copy_ behavior) + int64_t self_numel = self->numel(); + int64_t src_numel = src->numel(); + + if (self_numel != src_numel) { + ET_LOG( + Error, + "numel mismatch. 
self.numel()=%ld, src.numel()=%ld", + self_numel, + src_numel); + return Error::InvalidArgument; + } + + // Get tensor metadata + int64_t* self_strides; + int64_t* src_strides; + aoti_torch_get_strides(self, &self_strides); + aoti_torch_get_strides(src, &src_strides); + + int64_t* self_sizes; + int64_t* src_sizes; + aoti_torch_get_sizes(self, &self_sizes); + aoti_torch_get_sizes(src, &src_sizes); + + // Determine device locations + cudaPointerAttributes srcAttributes{}; + cudaPointerAttributes dstAttributes{}; + + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&srcAttributes, src->data_ptr())); + + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&dstAttributes, self->data_ptr())); + + bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice; + bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice; + + // Check if tensors have the same schema (sizes, strides, dtype) for fast path + bool same_schema = true; + for (int i = 0; i < self->dim(); i++) { + if (self_strides[i] != src_strides[i]) { + same_schema = false; + break; + } + } + + size_t total_bytes = src->nbytes(); + int64_t total_elements = self->numel(); + + if (same_schema) { + // Fast path: Direct memory copy since layouts match exactly + if (srcIsDevice && dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyDeviceToDevice)); + } else if (srcIsDevice && !dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyDeviceToHost)); + } else if (!srcIsDevice && dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyHostToDevice)); + } else { + std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes); + } + } else { + // Fallback path: Pointwise copy with stride-aware indexing + // This handles arbitrary tensor layouts and strides + + size_t element_size = dtype_to_element_size(self_dtype); + if (element_size == 0) { + ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype); + return Error::InvalidArgument; + } + + // Allocate temporary host memory for GPU tensors + float* src_host_data = nullptr; + float* dst_host_data = nullptr; + bool need_free_src = false; + bool need_free_dst = false; + + if (srcIsDevice) { + src_host_data = + static_cast(malloc(total_elements * sizeof(float))); + if (src_host_data == nullptr) { + ET_LOG(Error, "Failed to allocate memory for src_host_data"); + return Error::MemoryAllocationFailed; + } + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost)); + need_free_src = true; + } else { + src_host_data = static_cast(src->data_ptr()); + } + + if (dstIsDevice) { + dst_host_data = + static_cast(malloc(total_elements * sizeof(float))); + if (dst_host_data == nullptr) { + ET_LOG(Error, "Failed to allocate memory for dst_host_data"); + if (need_free_src) { + free(src_host_data); + } + return Error::MemoryAllocationFailed; + } + need_free_dst = true; + } else { + dst_host_data = static_cast(self->mutable_data_ptr()); + } + + // Perform pointwise copy with stride calculation + AOTITorchError copy_err = pointwise_copy_generic( + dst_host_data, + src_host_data, + self_sizes, + self_strides, + src_sizes, + src_strides, + self->dim(), + src->dim(), + total_elements); + + if (copy_err != Error::Ok) { + // Clean up temporary buffers before returning + if (need_free_src) { + 
free(src_host_data); + } + if (need_free_dst) { + free(dst_host_data); + } + return copy_err; + } + + // Copy result back to device if needed + if (dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + dst_host_data, + total_bytes, + cudaMemcpyHostToDevice)); + } + + // Clean up temporary buffers + if (need_free_src) { + free(src_host_data); + } + if (need_free_dst) { + free(dst_host_data); + } + } + + return Error::Ok; +} + AOTITorchError aoti_torch__reinterpret_tensor( Tensor* self, int64_t ndim, diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h index 4e9780840e1..bcec6621285 100644 --- a/backends/cuda/runtime/shims/memory.h +++ b/backends/cuda/runtime/shims/memory.h @@ -116,6 +116,31 @@ AOTITorchError aoti_torch__reinterpret_tensor( int64_t storage_offset, Tensor** ret_new_tensor); +/** + * Copies data from source tensor to destination tensor. + * + * This function implements copy function for tensors living in CUDA AOTI + * backend. It supports copying between tensors with different shapes (as long + * as they have the same total number of elements) and different memory + * layouts/strides. + * + * Note that currently this function does not support copying between tensors + * with different dtypes. + * + * @param self Destination tensor (data will be overwritten) + * @param src Source tensor (data will be copied from this tensor) + * @param non_blocking Whether the copy should be non-blocking (currently + * ignored) + * + * @return Error::Ok on success, appropriate error code on failure: + * - Error::InvalidArgument: null pointers, dtype mismatch, numel + * mismatch + * - Error::MemoryAllocationFailed: failed to allocate temporary memory + * - Error::Internal: CUDA operation failures + */ +AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking); + // Function to clear all tensors from internal storage void clear_all_tensors(); } // extern "C" diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index ac6d2072d58..fcb95a0beb7 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -31,3 +31,4 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object") cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2") cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor") + cuda_shim_cpp_unittest("aoti_torch_copy_") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp new file mode 100644 index 00000000000..7579eaef039 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp @@ -0,0 +1,398 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; + +// Test fixture for aoti_torch_copy_ tests +class AOTITorchCopyTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors with specific data + Tensor* create_test_tensor_with_data( + const std::vector& sizes, + const std::vector& data, + const std::vector& strides = {}, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA), + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Fill tensor with data + size_t total_bytes = data.size() * sizeof(float); + if (device_type == static_cast(SupportedDevices::CUDA)) { + cudaError_t memcpy_err = cudaMemcpy( + tensor->mutable_data_ptr(), + data.data(), + total_bytes, + cudaMemcpyHostToDevice); + // Note: Error is checked but we don't fail the function + // This allows tests to proceed and handle errors as needed + (void)memcpy_err; // Suppress unused variable warning + } else { // CPU + std::memcpy(tensor->mutable_data_ptr(), data.data(), total_bytes); + } + + return tensor; + } + + // Helper to get data from tensor + std::vector get_tensor_data(Tensor* tensor) { + if (!tensor) { + return {}; + } + + size_t num_elements = tensor->numel(); + std::vector data(num_elements); + + // Determine if this is a CUDA tensor + cudaPointerAttributes attributes{}; + cudaError_t err = cudaPointerGetAttributes(&attributes, tensor->data_ptr()); + bool is_device = + (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice); + + if (is_device) { + cudaError_t memcpy_err = cudaMemcpy( + data.data(), + tensor->data_ptr(), + num_elements * sizeof(float), + cudaMemcpyDeviceToHost); + // Note: Error is checked but we don't fail the function + // This allows tests to proceed and handle errors as needed + (void)memcpy_err; // Suppress unused variable warning + } else { + std::memcpy( + data.data(), tensor->data_ptr(), num_elements * sizeof(float)); + } + + return data; + } + + // Helper to verify two tensors have same data + bool tensors_equal(Tensor* a, Tensor* b, float tolerance = 1e-6f) { + if (!a || !b) { + return false; + } + if (a->numel() != b->numel()) { + return false; + } + + auto data_a = get_tensor_data(a); + auto data_b = get_tensor_data(b); + + for (size_t i = 0; i < data_a.size(); ++i) { + if (std::abs(data_a[i] - data_b[i]) > tolerance) { + return false; + } + } + return true; 
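+    // Note: the comparison above is element-wise with an absolute tolerance;
+    // get_tensor_data() stages device tensors to host memory first, so this
+    // helper works for both CPU and CUDA tensors.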
+ } +}; + +// Test basic copy functionality - same schema (fast path) +TEST_F(AOTITorchCopyTest, BasicCopySameSchema) { + // Create source tensor with test data + std::vector sizes = {2, 3}; + std::vector src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + Tensor* src = create_test_tensor_with_data(sizes, src_data); + EXPECT_NE(src, nullptr); + + // Create destination tensor with same schema + Tensor* dst = + create_test_tensor_with_data(sizes, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + EXPECT_NE(dst, nullptr); + + // Perform copy + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify copy was successful + EXPECT_TRUE(tensors_equal(dst, src)); +} + +// Test copy with different strides (pointwise fallback) +TEST_F(AOTITorchCopyTest, CopyDifferentStrides) { + // Create source tensor (2x3) with contiguous layout + std::vector src_sizes = {2, 3}; + std::vector src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + Tensor* src = create_test_tensor_with_data(src_sizes, src_data); + EXPECT_NE(src, nullptr); + + // Create destination tensor with transposed strides + std::vector dst_strides = {1, 2}; // Column-major layout + Tensor* dst = create_test_tensor_with_data( + src_sizes, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, dst_strides); + EXPECT_NE(dst, nullptr); + + // Perform copy - this should use pointwise fallback + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify the copy worked correctly by checking specific elements + auto dst_data = get_tensor_data(dst); + auto src_data_check = get_tensor_data(src); + + // For transposed layout, the data should be rearranged + EXPECT_EQ(dst_data.size(), 6); + EXPECT_EQ(src_data_check.size(), 6); +} + +// Test copy between CPU and CUDA tensors +TEST_F(AOTITorchCopyTest, CopyCPUToCUDA) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Create CPU tensor + Tensor* cpu_tensor = create_test_tensor_with_data( + sizes, + data, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU)); // CPU + EXPECT_NE(cpu_tensor, nullptr); + + // Create CUDA tensor + Tensor* cuda_tensor = create_test_tensor_with_data( + sizes, + {0.0f, 0.0f, 0.0f, 0.0f}, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA)); // CUDA + EXPECT_NE(cuda_tensor, nullptr); + + // Copy from CPU to CUDA + AOTITorchError error = aoti_torch_copy_(cuda_tensor, cpu_tensor, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify copy + EXPECT_TRUE(tensors_equal(cuda_tensor, cpu_tensor)); +} + +// Test copy between CUDA and CPU tensors +TEST_F(AOTITorchCopyTest, CopyCUDAToCPU) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Create CUDA tensor + Tensor* cuda_tensor = create_test_tensor_with_data( + sizes, + data, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA)); // CUDA + EXPECT_NE(cuda_tensor, nullptr); + + // Create CPU tensor + Tensor* cpu_tensor = create_test_tensor_with_data( + sizes, + {0.0f, 0.0f, 0.0f, 0.0f}, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU)); // CPU + EXPECT_NE(cpu_tensor, nullptr); + + // Copy from CUDA to CPU + AOTITorchError error = aoti_torch_copy_(cpu_tensor, cuda_tensor, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify copy + EXPECT_TRUE(tensors_equal(cpu_tensor, cuda_tensor)); +} + +// Test copy with bf16 dtype support +TEST_F(AOTITorchCopyTest, CopyBf16Tensors) { + // Test that bf16 tensors can be 
created and copied + std::vector sizes = {2, 3}; + std::vector src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Note: We create float32 data but the tensor will be created with bf16 dtype + // This simulates creating bf16 tensors + Tensor* src = create_test_tensor_with_data( + sizes, + src_data, + {}, // default strides + static_cast(SupportedDTypes::BFLOAT16), // bf16 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(src, nullptr); + + // Create destination tensor with bf16 dtype + std::vector dst_init(6, 0.0f); + Tensor* dst = create_test_tensor_with_data( + sizes, + dst_init, + {}, // default strides + static_cast(SupportedDTypes::BFLOAT16), // bf16 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(dst, nullptr); + + // Perform copy between bf16 tensors + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify that both tensors have the expected dtype + int32_t src_dtype, dst_dtype; + aoti_torch_get_dtype(src, &src_dtype); + aoti_torch_get_dtype(dst, &dst_dtype); + + EXPECT_EQ(src_dtype, static_cast(SupportedDTypes::BFLOAT16)); + EXPECT_EQ(dst_dtype, static_cast(SupportedDTypes::BFLOAT16)); + + // Verify copy was successful by checking numel matches + EXPECT_EQ(src->numel(), dst->numel()); + EXPECT_EQ(src->numel(), 6); +} + +// Test copy between different dtypes should fail +TEST_F(AOTITorchCopyTest, CopyDTypeMismatchError) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Create float32 tensor + Tensor* float32_tensor = create_test_tensor_with_data( + sizes, + data, + {}, // default strides + static_cast(SupportedDTypes::FLOAT32), // float32 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(float32_tensor, nullptr); + + // Create bf16 tensor + Tensor* bf16_tensor = create_test_tensor_with_data( + sizes, + {0.0f, 0.0f, 0.0f, 0.0f}, + {}, // default strides + static_cast(SupportedDTypes::BFLOAT16), // bf16 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(bf16_tensor, nullptr); + + // Attempting to copy between different dtypes should fail + AOTITorchError error = aoti_torch_copy_(bf16_tensor, float32_tensor, 0); + EXPECT_EQ(error, Error::InvalidArgument); + + // Reverse direction should also fail + error = aoti_torch_copy_(float32_tensor, bf16_tensor, 0); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error conditions +TEST_F(AOTITorchCopyTest, ErrorHandling) { + std::vector sizes = {2, 3}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + Tensor* valid_tensor = create_test_tensor_with_data(sizes, data); + EXPECT_NE(valid_tensor, nullptr); + + // Test null pointers + AOTITorchError error = aoti_torch_copy_(nullptr, valid_tensor, 0); + EXPECT_NE(error, Error::Ok); + + error = aoti_torch_copy_(valid_tensor, nullptr, 0); + EXPECT_NE(error, Error::Ok); + + // Test numel mismatch (different total number of elements) + std::vector different_numel_sizes = { + 2, 3, 4}; // 24 elements vs 6 elements + std::vector different_data(24, 1.0f); + Tensor* different_numel = + create_test_tensor_with_data(different_numel_sizes, different_data); + EXPECT_NE(different_numel, nullptr); + + error = aoti_torch_copy_(valid_tensor, different_numel, 0); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test copy from 1D to 3D with same total elements +TEST_F(AOTITorchCopyTest, Copy1DTo3DSameNumel) { + 
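+  // copy_ follows the PyTorch copy_ convention documented in memory.h: only
+  // the total element counts need to match, so an 8-element 1-D source can
+  // fill a {2, 2, 2} destination while preserving the flat element order.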
// Source tensor: 8 elements in 1D
+  std::vector<int64_t> src_sizes = {8};
+  std::vector<float> src_data = {
+      1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+
+  Tensor* src = create_test_tensor_with_data(src_sizes, src_data);
+  EXPECT_NE(src, nullptr);
+
+  // Destination tensor: 2x2x2 = 8 elements (different shape, same total)
+  std::vector<int64_t> dst_sizes = {2, 2, 2};
+  std::vector<float> dst_init(8, 0.0f);
+  Tensor* dst = create_test_tensor_with_data(dst_sizes, dst_init);
+  EXPECT_NE(dst, nullptr);
+
+  // Should succeed: same total number of elements
+  AOTITorchError error = aoti_torch_copy_(dst, src, 0);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Verify the data was copied correctly
+  auto dst_data = get_tensor_data(dst);
+  EXPECT_EQ(dst_data.size(), 8);
+
+  // Check some specific elements to verify correct copying
+  EXPECT_FLOAT_EQ(dst_data[0], 1.0f);
+  EXPECT_FLOAT_EQ(dst_data[7], 8.0f);
+}
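To make the stride handling in the fallback path concrete, the short standalone program below walks a 2x3 row-major source into a column-major destination using the same linear-index to multi-index to strided-offset scheme as pointwise_copy_generic, linear_to_indices, and calculate_linear_offset in memory.cpp. It is an illustrative sketch only, not part of the patch: the two helper bodies mirror the diff above, while the main() driver and the concrete sizes, strides, and values are invented for the example (they match the layout exercised by the CopyDifferentStrides test).

#include <cassert>
#include <cstdint>
#include <vector>

namespace {

// Peel dimensions from innermost to outermost (same logic as memory.cpp).
void linear_to_indices(
    int64_t linear_idx,
    const int64_t* sizes,
    int64_t ndim,
    int64_t* indices) {
  for (int64_t i = ndim - 1; i >= 0; --i) {
    indices[i] = linear_idx % sizes[i];
    linear_idx /= sizes[i];
  }
}

// Dot product of indices and strides gives the element offset.
int64_t calculate_linear_offset(
    const int64_t* indices,
    const int64_t* strides,
    int64_t ndim) {
  int64_t offset = 0;
  for (int64_t i = 0; i < ndim; ++i) {
    offset += indices[i] * strides[i];
  }
  return offset;
}

} // namespace

int main() {
  // 2x3 source in row-major order (strides {3, 1}); the destination uses the
  // column-major strides {1, 2} from the CopyDifferentStrides test.
  const std::vector<int64_t> sizes = {2, 3};
  const std::vector<int64_t> src_strides = {3, 1};
  const std::vector<int64_t> dst_strides = {1, 2};
  const std::vector<float> src = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<float> dst(src.size(), 0.f);

  std::vector<int64_t> idx(sizes.size());
  for (int64_t linear = 0; linear < static_cast<int64_t>(src.size());
       ++linear) {
    linear_to_indices(linear, sizes.data(), 2, idx.data());
    const int64_t s =
        calculate_linear_offset(idx.data(), src_strides.data(), 2);
    const int64_t d =
        calculate_linear_offset(idx.data(), dst_strides.data(), 2);
    dst[d] = src[s];
  }

  // Element (1, 2) holds 6.0f; its column-major offset is 1*1 + 2*2 = 5.
  assert(dst[5] == 6.f);
  return 0;
}

When the destination and source strides already match, aoti_torch_copy_ skips this per-element walk and takes the fast path instead: a single cudaMemcpy (or std::memcpy for host-to-host), with the copy direction chosen from the cudaPointerGetAttributes results.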