From 975fb6b7eea35566f3b97c931fb4c83e55ea1dac Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 30 Sep 2025 01:10:08 -0700
Subject: [PATCH] aoti_torch_create_tensor_from_blob_v2

Summary:
This diff introduces aoti_torch_create_tensor_from_blob_v2, a function that
creates a tensor from an existing data blob with caller-specified sizes and
strides. Note that, unlike aoti_torch_empty_strided, a tensor created by
aoti_torch_create_tensor_from_blob_v2 does not own the memory blob, so
deleting the tensor does not free the underlying memory.

Reviewed By:

Differential Revision:

[ghstack-poisoned]
---
 backends/aoti/utils.h                         |  17 +
 backends/cuda/runtime/shims/memory.cpp        | 187 ++++-
 backends/cuda/runtime/shims/memory.h          |  39 +-
 backends/cuda/runtime/shims/tests/targets.bzl |   1 +
 ..._aoti_torch_create_tensor_from_blob_v2.cpp | 754 ++++++++++++++++++
 backends/cuda/runtime/shims/utils.h           |  20 +
 6 files changed, 981 insertions(+), 37 deletions(-)
 create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp

diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h
index 22734935df2..1c872e08648 100644
--- a/backends/aoti/utils.h
+++ b/backends/aoti/utils.h
@@ -73,6 +73,23 @@ inline AOTITorchError validate_storage_offset(int64_t storage_offset) {
   return Error::Ok;
 }
 
+// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
+// Contiguous format means strides decrease from left to right:
+// For NCHW: strides = [C*H*W, H*W, W, 1]
+inline bool is_tensor_contiguous(
+    int64_t ndim,
+    const int64_t* sizes,
+    const int64_t* strides) {
+  int64_t expected_stride = 1;
+  for (int64_t i = ndim - 1; i >= 0; i--) {
+    if (strides[i] != expected_stride) {
+      return false;
+    }
+    expected_stride *= sizes[i];
+  }
+  return true;
+}
+
 } // extern "C"
 
 } // namespace aoti
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index 12a1d59e5e1..94f589aece6 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -15,29 +15,10 @@
 #include
 #include // For posix_memalign
 #include
+#include
 #include
 #include
 
-// CUDA error checking macro
-#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
-  do {                                      \
-    const cudaError_t err = EXPR;           \
-    if (err == cudaSuccess) {               \
-      break;                                \
-    }                                       \
-    ET_LOG(                                 \
-        Error,                              \
-        "%s:%d CUDA error: %s",             \
-        __FILE__,                           \
-        __LINE__,                           \
-        cudaGetErrorString(err));           \
-    return Error::Internal;                 \
-  } while (0)
-
-// Kernel launch check macro
-#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
-  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
-
 namespace executorch {
 namespace backends {
 namespace cuda {
@@ -46,12 +27,122 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
+using executorch::backends::aoti::validate_storage_offset;
 
 // Global storage for tensors and their metadata
 std::unordered_set<std::shared_ptr<Tensor>> tensors;
 
+// Reference counting for memory addresses
+// Maps memory address to number of tensors using it
+// Special value: NOT_OWN (-1) means tensor never owns the memory
+constexpr int32_t NOT_OWN = -1;
+std::unordered_map<void*, int32_t> memory_to_n_tensor;
+
 extern "C" {
 
+AOTITorchError aoti_torch_create_tensor_from_blob_v2(
+    void* data,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    Tensor** ret_new_tensor,
+    int32_t layout,
+    const uint8_t* 
opaque_metadata, + int64_t opaque_metadata_size) { + // TODO(gasoonjia): verify given data is on the target device + (void)device_type; + (void)opaque_metadata; + (void)layout; + (void)opaque_metadata_size; + + // Validate input parameters first + if (data == nullptr) { + ET_LOG( + Error, + "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null"); + return Error::InvalidArgument; + } + + if (sizes_ptr == nullptr && ndim > 0) { + ET_LOG( + Error, + "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null"); + return Error::InvalidArgument; + } + + if (ret_new_tensor == nullptr) { + ET_LOG( + Error, + "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null"); + return Error::InvalidArgument; + } + + // Check that device_index is always 0 + if (device_index != 0) { + ET_LOG(Error, "device_index must be 0, got: %d", device_index); + return Error::InvalidArgument; + } + + // Validate dtype using SupportedDTypes from utils.h + AOTITorchError dtype_error = validate_dtype(dtype); + if (dtype_error != Error::Ok) { + return dtype_error; + } + + // Storage offset must be 0 since from_blob cannot handle different offsets + AOTITorchError storage_offset_error = validate_storage_offset(storage_offset); + if (storage_offset_error != Error::Ok) { + return storage_offset_error; + } + + // Convert sizes to the format expected by ExecutorTorch using SizesType + std::vector sizes = + convert_sizes_to_vector(ndim, sizes_ptr); + + // Convert strides using the common helper function with StridesType + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create ExecutorTorch tensor that wraps the existing memory + // Note: We're NOT copying the data, just wrapping it + auto tensor = executorch::extension::from_blob( + data, // existing memory (don't copy!) + sizes, // tensor dimensions + strides, // tensor strides (allows different strides) + dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType + ); + + if (!tensor) { + ET_LOG(Error, "Failed to create tensor from blob"); + return Error::InvalidArgument; + } + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_new_tensor = tensor.get(); + + // Check if this memory address is already being tracked + auto memory_it = memory_to_n_tensor.find(data); + if (memory_it != memory_to_n_tensor.end()) { + ET_LOG( + Error, + "Memory address %p is already being tracked by another tensor", + data); + return Error::InvalidArgument; + } + + // Mark this memory as NOT_OWN since tensor created from blob never owns + // memory + memory_to_n_tensor[data] = NOT_OWN; + + return Error::Ok; +} + AOTITorchError aoti_torch_empty_strided( int64_t ndim, const int64_t* sizes_ptr, @@ -120,6 +211,9 @@ AOTITorchError aoti_torch_empty_strided( tensors.insert(tensor); *ret_new_tensor = tensor.get(); + // This tensor owns the memory it allocated, set reference count to 1 + memory_to_n_tensor[ptr] = 1; + return Error::Ok; } @@ -164,26 +258,47 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) { if (it->get() == tensor) { // Get the tensor before erasing auto tensor_ptr = *it; - void* data_ptr = tensor_ptr->mutable_data_ptr(); - // Determine if it's GPU memory - cudaPointerAttributes attributes{}; - ET_CUDA_CHECK_OR_RETURN_ERROR( - cudaPointerGetAttributes(&attributes, data_ptr)); - - // et tensor does not own data; need to free them manually. 
- if (attributes.type == cudaMemoryTypeManaged) { - // This is CUDA managed memory - free with proper synchronization - ET_CUDA_CHECK_OR_RETURN_ERROR( - cudaDeviceSynchronize()); // Wait for all operations to complete - // BEFORE freeing - ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr)); + // Find the reference count for this memory address + auto memory_it = memory_to_n_tensor.find(data_ptr); + if (memory_it != memory_to_n_tensor.end()) { + int32_t ref_count = memory_it->second; + + if (ref_count == NOT_OWN) { + // Tensor never owned the memory, skip freeing + // Just remove tensor from tracking + tensors.erase(it); + return Error::Ok; + } else if (ref_count == 1) { + // Only current tensor using this memory, free it + // Determine if it's GPU memory + cudaPointerAttributes attributes{}; + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&attributes, data_ptr)); + + if (attributes.type == cudaMemoryTypeManaged) { + // This is CUDA managed memory - free with proper synchronization + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize()); + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr)); + } else { + // This is CPU memory - free immediately + free(data_ptr); + data_ptr = nullptr; + } + + // Remove from memory tracking + memory_to_n_tensor.erase(memory_it); + } else if (ref_count > 1) { + // Other tensors still using this memory, just decrement count + memory_to_n_tensor[data_ptr] = ref_count - 1; + } } else { - // This is CPU memory - free immediately - free(data_ptr); + ET_LOG(Error, "Internal error: memory not found during deletion"); + return Error::Internal; } - // Remove from set (this will call the destructor if it's the last + + // Remove tensor from set (this will call the destructor if it's the last // reference) tensors.erase(it); return Error::Ok; diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h index 93bd9c30e70..7f4c56a8000 100644 --- a/backends/cuda/runtime/shims/memory.h +++ b/backends/cuda/runtime/shims/memory.h @@ -21,6 +21,44 @@ using executorch::backends::aoti::Tensor; extern "C" { +/** + * Creates a tensor object from an existing memory blob without copying the + * data. The tensor will wrap the provided memory and will not take ownership of + * it. When the tensor is deleted, the original memory will remain valid and + * must be freed by the caller. 
+ * + * @param data Pointer to the memory blob to wrap (must not be null) + * @param ndim Number of dimensions in the tensor + * @param sizes_ptr Pointer to array of dimension sizes (using SizesType) + * @param strides_ptr Pointer to array of strides for each dimension (using + * StridesType, can be null for contiguous) + * @param storage_offset Storage offset (must be 0 for current implementation) + * @param dtype Data type identifier (supports FLOAT32 and BFLOAT16 from + * SupportedDTypes) + * @param device_type Device type (CPU=0, CUDA=1 from SupportedDevices) + * @param device_index Device index (must be 0 for current implementation) + * @param ret_new_tensor Output parameter for the created tensor (must not be + * null) + * @param layout Tensor layout identifier (0=strided) + * @param opaque_metadata Optional metadata pointer (can be null) + * @param opaque_metadata_size Size of opaque metadata in bytes + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + /** * Creates an uninitialized tensor with specified dimensions, strides, and * dtyper on either CPU or CUDA device. @@ -55,7 +93,6 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); // Function to clear all tensors from internal storage void clear_all_tensors(); - } // extern "C" } // namespace cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index 1db52ce1b97..dce7d0be39c 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -29,3 +29,4 @@ def define_common_targets(): """ cuda_shim_cpp_unittest("aoti_torch_empty_strided") cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object") + cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp new file mode 100644 index 00000000000..2cb12719782 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp @@ -0,0 +1,754 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_create_tensor_from_blob_v2 tests +class AOTITorchCreateTensorFromBlobV2Test : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + + // Clean up any allocated memory buffers + for (void* ptr : cuda_memory_buffers_) { + if (ptr) { + cudaError_t cuda_err = cudaFree(ptr); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Failed to free CUDA memory: " << cudaGetErrorString(cuda_err); + } + } + cuda_memory_buffers_.clear(); + + for (void* ptr : cpu_memory_buffers_) { + if (ptr) { + free(ptr); + } + } + cpu_memory_buffers_.clear(); + } + + // Helper to allocate CUDA memory and track it for cleanup + void* allocate_cuda_memory(size_t bytes) { + void* ptr; + cudaError_t err = cudaMallocManaged(&ptr, bytes); + if (err == cudaSuccess) { + cuda_memory_buffers_.push_back(ptr); + return ptr; + } + return nullptr; + } + + // Helper to allocate CPU memory and track it for cleanup + void* allocate_cpu_memory(size_t bytes) { + void* ptr; + int result = posix_memalign(&ptr, 16, bytes); // 16-byte aligned + if (result == 0 && ptr != nullptr) { + cpu_memory_buffers_.push_back(ptr); + return ptr; + } + return nullptr; + } + + // Helper to calculate number of elements from sizes + int64_t calculate_numel(const std::vector& sizes) { + int64_t numel = 1; + for (int64_t size : sizes) { + numel *= size; + } + return numel; + } + + // Helper to calculate contiguous strides from sizes + std::vector calculate_contiguous_strides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + + strides[sizes.size() - 1] = 1; + // Use int64_t and check for underflow to avoid unsigned integer wraparound + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; + } + + private: + std::vector cuda_memory_buffers_; + std::vector cpu_memory_buffers_; +}; + +// Test basic functionality with CUDA memory +TEST_F(AOTITorchCreateTensorFromBlobV2Test, BasicFunctionalityCUDA) { + // Test 1D tensor + std::vector sizes_1d = {5}; + std::vector strides_1d = calculate_contiguous_strides(sizes_1d); + + // Allocate CUDA memory + size_t bytes = calculate_numel(sizes_1d) * sizeof(float); + void* cuda_data = allocate_cuda_memory(bytes); + ASSERT_NE(cuda_data, nullptr); + + Tensor* tensor_1d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + cuda_data, + sizes_1d.size(), + sizes_1d.data(), + strides_1d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_1d, + 0, // layout 
(strided) + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_1d, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor_1d->dim(), 1); + EXPECT_EQ(tensor_1d->size(0), 5); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor_1d->mutable_data_ptr(); + EXPECT_EQ(tensor_data, cuda_data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor_1d); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) For CUDA memory, check that we can still access it (synchronously) + // after tensor deletion + float pattern_value = 42.0f; + cudaError_t cuda_err = cudaMemcpy( + cuda_data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, cuda_data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test basic functionality with CPU memory +TEST_F(AOTITorchCreateTensorFromBlobV2Test, BasicFunctionalityCPU) { + // Test 2D tensor + std::vector sizes_2d = {3, 4}; + std::vector strides_2d = calculate_contiguous_strides(sizes_2d); + + // Allocate CPU memory + size_t bytes = calculate_numel(sizes_2d) * sizeof(float); + void* cpu_data = allocate_cpu_memory(bytes); + ASSERT_NE(cpu_data, nullptr); + + Tensor* tensor_2d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + cpu_data, + sizes_2d.size(), + sizes_2d.data(), + strides_2d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU), + 0, // device index + &tensor_2d, + 0, // layout (strided) + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_2d, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor_2d->dim(), 2); + EXPECT_EQ(tensor_2d->size(0), 3); + EXPECT_EQ(tensor_2d->size(1), 4); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor_2d->mutable_data_ptr(); + EXPECT_EQ(tensor_data, cpu_data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor_2d); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) For CPU memory, directly write and read to verify accessibility + float* float_ptr = reinterpret_cast(cpu_data); + float pattern_value = 42.0f; + *float_ptr = pattern_value; + EXPECT_EQ(*float_ptr, pattern_value) + << "Original CPU memory should still be accessible after tensor deletion"; +} + +// Test with invalid dtype +TEST_F(AOTITorchCreateTensorFromBlobV2Test, InvalidDtype) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + 999, // invalid dtype + static_cast(SupportedDevices::CUDA), + 0, // device index + 
&tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with non-zero storage offset (should fail since from_blob cannot handle +// offsets) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, NonZeroStorageOffset) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 1, // non-zero storage_offset (should fail since from_blob cannot handle + // offsets) + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with custom strides (using stride parameter but still contiguous) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, CustomContiguousStrides) { + std::vector sizes = {2, 3}; + // Use the correct contiguous strides but pass them explicitly + std::vector contiguous_strides = {3, 1}; // Proper contiguous strides + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + contiguous_strides.data(), // Explicitly pass contiguous strides + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Verify strides were properly set (we can check via aoti_torch_get_strides) + int64_t* tensor_strides; + error = aoti_torch_get_strides(tensor, &tensor_strides); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(tensor_strides[0], 3); + EXPECT_EQ(tensor_strides[1], 1); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test with null data pointer +TEST_F(AOTITorchCreateTensorFromBlobV2Test, NullDataPointer) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + nullptr, // 
null data pointer + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test scalar tensor (0D) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, ScalarTensor) { + std::vector sizes = {}; // 0D tensor + std::vector strides = {}; // Empty strides for scalar + + size_t bytes = sizeof(float); // Single element + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor = nullptr; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 0); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test zero-sized tensor +TEST_F(AOTITorchCreateTensorFromBlobV2Test, ZeroSizedTensor) { + std::vector sizes = {0, 5}; // Zero elements + std::vector strides = calculate_contiguous_strides(sizes); + + // Even for zero-sized tensor, we need some memory allocated + size_t bytes = sizeof(float); // Minimum allocation + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 0); + EXPECT_EQ(tensor->size(1), 5); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + 
EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test multi-dimensional tensors +TEST_F(AOTITorchCreateTensorFromBlobV2Test, MultiDimensionalTensors) { + // Test 3D tensor + std::vector sizes_3d = {2, 3, 4}; + std::vector strides_3d = calculate_contiguous_strides(sizes_3d); + + size_t bytes_3d = calculate_numel(sizes_3d) * sizeof(float); + void* data_3d = allocate_cuda_memory(bytes_3d); + ASSERT_NE(data_3d, nullptr); + + Tensor* tensor_3d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data_3d, + sizes_3d.size(), + sizes_3d.data(), + strides_3d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_3d, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_3d, nullptr); + EXPECT_EQ(tensor_3d->dim(), 3); + EXPECT_EQ(tensor_3d->size(0), 2); + EXPECT_EQ(tensor_3d->size(1), 3); + EXPECT_EQ(tensor_3d->size(2), 4); + + // Test 4D tensor + std::vector sizes_4d = {2, 3, 4, 5}; + std::vector strides_4d = calculate_contiguous_strides(sizes_4d); + + size_t bytes_4d = calculate_numel(sizes_4d) * sizeof(float); + void* data_4d = allocate_cuda_memory(bytes_4d); + ASSERT_NE(data_4d, nullptr); + + Tensor* tensor_4d; + error = aoti_torch_create_tensor_from_blob_v2( + data_4d, + sizes_4d.size(), + sizes_4d.data(), + strides_4d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_4d, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_4d, nullptr); + EXPECT_EQ(tensor_4d->dim(), 4); + EXPECT_EQ(tensor_4d->size(0), 2); + EXPECT_EQ(tensor_4d->size(1), 3); + EXPECT_EQ(tensor_4d->size(2), 4); + EXPECT_EQ(tensor_4d->size(3), 5); +} + +// Test tensor data pointer consistency +TEST_F(AOTITorchCreateTensorFromBlobV2Test, DataPointerConsistency) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* original_data = allocate_cuda_memory(bytes); + ASSERT_NE(original_data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + original_data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check that the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, original_data); +} + +// Test creating multiple tensors from different blobs +TEST_F(AOTITorchCreateTensorFromBlobV2Test, MultipleTensorsFromBlobs) { + const int num_tensors = 5; + std::vector tensors; + std::vector data_ptrs; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {i + 1, i + 2}; + std::vector strides = 
calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + data_ptrs.push_back(data); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + tensors.push_back(tensor); + + // Verify dimensions + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), i + 1); + EXPECT_EQ(tensor->size(1), i + 2); + + // Verify the tensor uses the correct data pointer + EXPECT_EQ(tensor->mutable_data_ptr(), data); + } + + // Verify all tensors have different data pointers + for (int i = 0; i < num_tensors; i++) { + EXPECT_EQ(tensors[i]->mutable_data_ptr(), data_ptrs[i]); + for (int j = i + 1; j < num_tensors; j++) { + EXPECT_NE(tensors[i]->mutable_data_ptr(), tensors[j]->mutable_data_ptr()); + } + } +} + +// Test deletion of tensor created from blob (should not free the original +// memory) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, DeletionDoesNotFreeOriginalMemory) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // The original memory should still be valid (we'll free it in teardown) + // We can't easily test if the memory is still valid without risking crashes, + // but the test should pass without issues if memory management is correct +} + +// Test with opaque metadata +TEST_F(AOTITorchCreateTensorFromBlobV2Test, WithOpaqueMetadata) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + // Create some opaque metadata + std::vector metadata = {0x01, 0x02, 0x03, 0x04}; + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + metadata.data(), // opaque_metadata + metadata.size()); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); +} + +// Test stress test with many small tensors from blobs +TEST_F(AOTITorchCreateTensorFromBlobV2Test, StressTestManySmallTensors) { + const int num_tensors = 50; // Reduced for reasonable test time + std::vector 
tensors; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {1, 1}; // Minimal size + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + if (data == nullptr) { + // Skip if we run out of memory + continue; + } + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + if (error == Error::Ok && tensor != nullptr) { + tensors.push_back(tensor); + + // Verify the tensor uses the correct data pointer + EXPECT_EQ(tensor->mutable_data_ptr(), data); + } + } + + // Delete all created tensors + for (Tensor* tensor : tensors) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + } +} diff --git a/backends/cuda/runtime/shims/utils.h b/backends/cuda/runtime/shims/utils.h index 23943391b50..38e56ca45a1 100644 --- a/backends/cuda/runtime/shims/utils.h +++ b/backends/cuda/runtime/shims/utils.h @@ -14,6 +14,26 @@ #include #include +// CUDA error checking macro +#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \ + do { \ + const cudaError_t err = EXPR; \ + if (err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(err)); \ + return Error::Internal; \ + } while (0) + +// Kernel launch check macro +#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \ + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError()) + namespace executorch { namespace backends { namespace cuda {
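
Note (not part of the patch): the sketch below illustrates the ownership contract described in the Summary: the caller allocates and frees the blob, and deleting the wrapping tensor leaves it untouched. The include paths are assumptions; SupportedDTypes, SupportedDevices, and the shim signatures are taken from the headers and tests above, and error handling is reduced to early returns.

// Ownership sketch for aoti_torch_create_tensor_from_blob_v2 (assumed include
// paths; the enums come from the cuda shim utils header used by the tests).
#include <cuda_runtime.h>
#include <vector>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/utils.h>

using namespace executorch::backends::aoti;
using namespace executorch::backends::cuda;
using namespace executorch::runtime;
using executorch::runtime::etensor::Tensor;

bool wrap_existing_blob_example() {
  std::vector<int64_t> sizes = {2, 3};
  std::vector<int64_t> strides = {3, 1}; // contiguous strides for a 2x3 tensor

  // The caller owns this buffer for its entire lifetime.
  void* blob = nullptr;
  if (cudaMallocManaged(&blob, 6 * sizeof(float)) != cudaSuccess) {
    return false;
  }

  Tensor* tensor = nullptr;
  AOTITorchError err = aoti_torch_create_tensor_from_blob_v2(
      blob,
      sizes.size(),
      sizes.data(),
      strides.data(),
      /*storage_offset=*/0, // must be 0; from_blob cannot apply an offset
      static_cast<int32_t>(SupportedDTypes::FLOAT32),
      static_cast<int32_t>(SupportedDevices::CUDA),
      /*device_index=*/0,
      &tensor,
      /*layout=*/0, // strided
      /*opaque_metadata=*/nullptr,
      /*opaque_metadata_size=*/0);
  if (err != Error::Ok) {
    cudaFree(blob);
    return false;
  }

  // The tensor aliases `blob` (tracked as NOT_OWN), so deleting it only drops
  // the wrapper; the blob itself is not freed by the shim.
  err = aoti_torch_delete_tensor_object(tensor);

  // Unlike tensors from aoti_torch_empty_strided (refcount 1, freed by the
  // shim on deletion), this memory is still the caller's to release.
  cudaFree(blob);
  return err == Error::Ok;
}

This is why memory_to_n_tensor distinguishes an owning reference count from the NOT_OWN sentinel: only memory allocated by the shim itself is ever freed in aoti_torch_delete_tensor_object.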