From 6d3a43b1e29c4e065f9ad67d06c179cae0a624f1 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Fri, 17 Oct 2025 11:47:19 -0700
Subject: [PATCH] make aoti_torch_empty_strided support creating incontiguous tensor

This diff modifies the `aoti_torch_empty_strided` function to support creating
incontiguous tensors. To achieve this, it:

1. updates the memory-size calculation to use both the tensor sizes and the
   strides, instead of the element count alone;
2. skips the stride check in ETensor when building with CUDA backend support,
   by adding and checking the `USE_CUDA_BACKEND` build macro.

We will bring the ETensor stride check back for every backend after migrating
to SlimTensor.

Differential Revision: [D84938258](https://our.internmc.facebook.com/intern/diff/D84938258/)

[ghstack-poisoned]
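For reference, the stride-aware storage size this patch adopts (modeled on
PyTorch's computeStorageNbytes) can be summarized as a standalone helper. The
snippet below is illustrative only; it is not the shim code itself, and the
function name is made up for the example:

    #include <cstdint>

    // Storage must cover the farthest reachable element: one element plus
    // sum_i stride[i] * (size[i] - 1); a zero-sized dim needs no storage.
    int64_t storage_size_elems(
        const int64_t* sizes, const int64_t* strides, int64_t ndim) {
      int64_t storage_size = 1;
      for (int64_t i = 0; i < ndim; i++) {
        if (sizes[i] == 0) {
          return 0;
        }
        storage_size += strides[i] * (sizes[i] - 1);
      }
      return storage_size;
    }

    // Example: sizes = {2, 3}, strides = {4, 1} (rows padded to 4 elements)
    // gives 1 + 4 * (2 - 1) + 1 * (3 - 1) = 7 elements, while numel is only
    // 2 * 3 = 6, so sizing the buffer by numel alone would under-allocate.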
---
 backends/cuda/runtime/shims/memory.cpp        |  29 ++++-
 .../tests/test_aoti_torch_empty_strided.cpp   | 111 +++++++++++++++++-
 extension/tensor/CMakeLists.txt               |   5 +
 extension/tensor/targets.bzl                  |   5 +
 extension/tensor/tensor_ptr.cpp               |   7 +-
 5 files changed, 146 insertions(+), 11 deletions(-)

diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index 5d30d3124d9..3c803fe6445 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -210,10 +210,6 @@ AOTITorchError aoti_torch_empty_strided(
 
   // This requires us to reserve CUDA memory and put it into a ETensor
   void* ptr;
-  int64_t numel = 1;
-  for (int64_t i = 0; i < ndim; i++) {
-    numel *= sizes_ptr[i];
-  }
 
   ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
 
@@ -223,7 +219,29 @@ AOTITorchError aoti_torch_empty_strided(
       InvalidArgument,
       "Invalid element size for dtype: %d",
       dtype);
-  int64_t nbytes = numel * element_size;
+
+  // Calculate storage size based on strides, matching PyTorch's behavior.
+  // This is critical when sizes and strides don't match the expected
+  // contiguous layout. Reference: PyTorch's computeStorageNbytes in EmptyTensor.cpp
+  int64_t storage_size = 1; // storage offset (0) + 1
+  for (int64_t i = 0; i < ndim; i++) {
+    if (sizes_ptr[i] == 0) {
+      storage_size = 0;
+      break;
+    }
+    // For each dimension, add stride[i] * (size[i] - 1).
+    // This gives us the maximum offset in that dimension.
+    int64_t stride_i = (strides_ptr != nullptr) ? strides_ptr[i] : 0;
+    if (strides_ptr == nullptr) {
+      // Calculate contiguous stride if not provided
+      stride_i = 1;
+      for (int64_t j = i + 1; j < ndim; j++) {
+        stride_i *= sizes_ptr[j];
+      }
+    }
+    storage_size += stride_i * (sizes_ptr[i] - 1);
+  }
+  int64_t nbytes = storage_size * element_size;
 
   if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
     ET_CUDA_CHECK_OR_RETURN_ERROR(
@@ -259,7 +277,6 @@ AOTITorchError aoti_torch_empty_strided(
 
   // This tensor owns the memory it allocated, set reference count to 1
   memory_to_n_tensor[ptr] = 1;
-
   return Error::Ok;
 }
 
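As a side note on the strides_ptr == nullptr branch above: the fallback derives
contiguous (row-major) strides, in which case the stride-based storage size
reduces to numel. A minimal standalone sketch of that fallback, with an
illustrative name rather than the shim code:

    #include <cstdint>
    #include <vector>

    // Contiguous strides for the given sizes; e.g. sizes {2, 3, 4} yields
    // {12, 4, 1}, and the stride-based storage size then equals numel (24).
    std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& sizes) {
      std::vector<int64_t> strides(sizes.size(), 1);
      for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; i--) {
        strides[i] = strides[i + 1] * sizes[i + 1];
      }
      return strides;
    }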
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
index da65129f18a..d89065a3b4a 100644
--- a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
@@ -509,11 +509,11 @@ TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) {
   EXPECT_EQ(sizes_ptr[2], 3);
 }
 
-// Test different data types (only float32 is currently supported)
+// Test different data types (currently we support bf16, fp32 and int32)
 TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
   std::vector<int64_t> sizes = {2, 3};
 
-  // Test float32 (dtype 6) - currently the only supported type
+  // Test float32 (dtype 6) - one of the supported types
   Tensor* tensor_float32;
   AOTITorchError error = aoti_torch_empty_strided(
       sizes.size(),
@@ -527,7 +527,7 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
   EXPECT_EQ(error, Error::Ok);
   EXPECT_NE(tensor_float32, nullptr);
 
-  // Test unsupported data types should return error
+  // Test int32 (dtype 3) - one of the supported types
   Tensor* tensor_int32;
   error = aoti_torch_empty_strided(
       sizes.size(),
@@ -538,7 +538,8 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
       0, // device index
       &tensor_int32);
 
-  EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_int32, nullptr);
 
   // Test another unsupported data type
   Tensor* tensor_float64;
@@ -586,3 +587,105 @@
   EXPECT_EQ(tensor_5d->size(3), 4);
   EXPECT_EQ(tensor_5d->size(4), 5);
 }
+
+// Test incontiguous tensor creation - transpose-like layout
+TEST_F(AOTITorchEmptyStridedTest, IncontiguousTransposeLayout) {
+  // Create a tensor with transpose-like strides (column-major).
+  // For a 3x4 tensor in column-major order, strides should be [1, 3]:
+  // each row step is 1, and each column step is 3.
+  std::vector<int64_t> sizes = {3, 4};
+  std::vector<int64_t> strides = {1, 3}; // Column-major (incontiguous)
+
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify tensor properties
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 3);
+  EXPECT_EQ(tensor->size(1), 4);
+
+  // Verify the strides are what we specified
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+  EXPECT_EQ(strides_ptr[0], 1); // Column-major stride for dimension 0
+  EXPECT_EQ(strides_ptr[1], 3); // Column-major stride for dimension 1
+
+  // Verify that memory was allocated correctly for the incontiguous layout.
+  // Storage size should be:
+  //   stride[0] * (size[0] - 1) + stride[1] * (size[1] - 1) + 1
+  //   = 1 * (3 - 1) + 3 * (4 - 1) + 1 = 1 * 2 + 3 * 3 + 1 = 2 + 9 + 1 = 12 elements
+  // Total bytes = 12 * 4 = 48 bytes (for float32).
+  EXPECT_EQ(tensor->numel(), 12); // numel is still 3*4=12 for the logical shape
+
+  // The tensor should be accessible and writable
+  void* data_ptr = tensor->mutable_data_ptr();
+  EXPECT_NE(data_ptr, nullptr);
+
+  // Verify we can use CUDA to write to the memory
+  std::vector<float> test_data(12, 1.0f);
+  cudaError_t cuda_err = cudaMemcpy(
+      data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
+  EXPECT_EQ(cuda_err, cudaSuccess);
+}
+
+// Test incontiguous tensor creation - expanded/broadcasted stride pattern
+TEST_F(AOTITorchEmptyStridedTest, IncontiguousExpandedStrides) {
+  // Create a tensor with expanded strides (simulating broadcasting):
+  // a 2x3x4 tensor where the first dimension has stride 0, i.e. the first
+  // dimension is "broadcasted".
+  std::vector<int64_t> sizes = {2, 3, 4};
+  std::vector<int64_t> strides = {0, 4, 1}; // First dimension has stride 0
+
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify tensor properties
+  EXPECT_EQ(tensor->dim(), 3);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 3);
+  EXPECT_EQ(tensor->size(2), 4);
+
+  // Verify the strides are what we specified
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+  EXPECT_EQ(strides_ptr[0], 0); // Expanded dimension stride
+  EXPECT_EQ(strides_ptr[1], 4);
+  EXPECT_EQ(strides_ptr[2], 1);
+
+  // Verify that memory was allocated correctly for this incontiguous layout.
+  // Storage size should be:
+  //   stride[0] * (size[0] - 1) + stride[1] * (size[1] - 1) +
+  //       stride[2] * (size[2] - 1) + 1
+  //   = 0 * (2 - 1) + 4 * (3 - 1) + 1 * (4 - 1) + 1 = 0 + 8 + 3 + 1 = 12 elements
+  // Note: numel() returns the logical number of elements (2*3*4=24), not the
+  // storage size.
+  EXPECT_EQ(tensor->numel(), 24); // Logical numel is 2*3*4=24
+
+  // The tensor should be accessible and writable
+  void* data_ptr = tensor->mutable_data_ptr();
+  EXPECT_NE(data_ptr, nullptr);
+
+  // Verify we can use CUDA to write to the allocated memory.
+  // We only need to allocate 12 elements (storage size), not 24.
+  std::vector<float> test_data(12, 2.0f);
+  cudaError_t cuda_err = cudaMemcpy(
+      data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
+  EXPECT_EQ(cuda_err, cudaSuccess);
+}
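To make the expected allocation in IncontiguousTransposeLayout concrete, here
is a tiny illustrative helper (not part of the test) for the flat storage
offset under those strides:

    #include <cstdint>

    // Flat offset of element (row, col) with sizes {3, 4} and column-major
    // strides {1, 3}: offset = row * 1 + col * 3.
    int64_t flat_offset(int64_t row, int64_t col) {
      return row * 1 + col * 3;
    }

    // The farthest element is (2, 3) with offset 2 + 9 = 11, so the backing
    // buffer needs 12 elements, matching the storage size computed above.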
diff --git a/extension/tensor/CMakeLists.txt b/extension/tensor/CMakeLists.txt
index 2a8d9b17916..14502e527d1 100644
--- a/extension/tensor/CMakeLists.txt
+++ b/extension/tensor/CMakeLists.txt
@@ -24,6 +24,11 @@ target_include_directories(
 )
 target_compile_options(extension_tensor PUBLIC ${_common_compile_options})
 
+# Define USE_CUDA_BACKEND when building with CUDA backend support
+if(EXECUTORCH_BUILD_CUDA)
+  target_compile_definitions(extension_tensor PUBLIC USE_CUDA_BACKEND)
+endif()
+
 # Install libraries
 install(
   TARGETS extension_tensor
diff --git a/extension/tensor/targets.bzl b/extension/tensor/targets.bzl
index bf1485aaba5..0a7cfb82334 100644
--- a/extension/tensor/targets.bzl
+++ b/extension/tensor/targets.bzl
@@ -10,6 +10,10 @@ def define_common_targets():
     for aten_mode in get_aten_mode_options():
         aten_suffix = ("_aten" if aten_mode else "")
 
+        # Check if USE_CUDA_BACKEND flag is set via build config
+        use_cuda_backend = native.read_config("executorch", "use_cuda_backend", "false") == "true"
+        preprocessor_flags = ["-DUSE_CUDA_BACKEND"] if use_cuda_backend else []
+
         runtime.cxx_library(
             name = "tensor" + aten_suffix,
             srcs = [
@@ -25,6 +29,7 @@ def define_common_targets():
             visibility = [
                 "@EXECUTORCH_CLIENTS",
             ],
+            preprocessor_flags = preprocessor_flags,
             deps = [
                 "//executorch/runtime/core/exec_aten/util:dim_order_util" + aten_suffix,
                 "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
diff --git a/extension/tensor/tensor_ptr.cpp b/extension/tensor/tensor_ptr.cpp
index dab1a8ab176..ae1823fe0db 100644
--- a/extension/tensor/tensor_ptr.cpp
+++ b/extension/tensor/tensor_ptr.cpp
@@ -79,6 +79,11 @@ TensorPtr make_tensor_ptr(
       });
     }
   }
+
+// Skip stride calculation and incontiguous tensor check for CUDA backend since
+// AOTI-CUDA handles both contiguous and incontiguous tensors. This will be
+// removed after SlimTensor migration.
+#ifndef USE_CUDA_BACKEND
   std::vector<executorch::aten::StridesType> computed_strides(dim);
 
   auto error = runtime::dim_order_to_stride(
@@ -98,8 +103,8 @@ TensorPtr make_tensor_ptr(
           sizes[i]);
     }
   }
-
   strides = std::move(computed_strides);
+#endif // USE_CUDA_BACKEND
 
 #ifndef USE_ATEN_LIB
   executorch::aten::TensorImpl tensor_impl(
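For context on what the new #ifndef USE_CUDA_BACKEND guard skips: outside the
CUDA backend, make_tensor_ptr derives strides from the dim order and checks
caller-provided strides against them; under the guard, the caller's strides are
kept as-is. A standalone sketch of that derivation, with illustrative names
rather than the ExecuTorch API:

    #include <cstdint>
    #include <vector>

    // Contiguous strides implied by a dim order (outermost to innermost).
    // For sizes {3, 4} and dim_order {0, 1} this yields {4, 1}.
    std::vector<int64_t> strides_from_dim_order(
        const std::vector<int64_t>& sizes,
        const std::vector<int64_t>& dim_order) {
      std::vector<int64_t> strides(sizes.size());
      int64_t acc = 1;
      for (auto it = dim_order.rbegin(); it != dim_order.rend(); ++it) {
        strides[*it] = acc;
        acc *= sizes[*it];
      }
      return strides;
    }

The guard itself is driven by the EXECUTORCH_BUILD_CUDA CMake option and the
executorch.use_cuda_backend Buck config shown in the diffs above.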