diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h
index 82d30cdb4ef..22734935df2 100644
--- a/backends/aoti/utils.h
+++ b/backends/aoti/utils.h
@@ -36,6 +36,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
   switch (dtype) {
     case 6: // PyTorch's float32 dtype code
       return executorch::aten::ScalarType::Float;
+    case 15: // PyTorch's bfloat16 dtype code
+      return executorch::aten::ScalarType::BFloat16;
     // Future support for additional dtypes can be added here
     default:
       ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype);
diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS
new file mode 100644
index 00000000000..1aa38760e5a
--- /dev/null
+++ b/backends/cuda/runtime/TARGETS
@@ -0,0 +1,32 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.cxx_library(
+    name = "runtime_shims",
+    srcs = [
+        "shims/memory.cpp",
+        "shims/tensor_attribute.cpp",
+    ],
+    headers = [
+        "shims/memory.h",
+        "shims/tensor_attribute.h",
+        "shims/utils.h",
+    ],
+    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+    link_whole = True,
+    supports_python_dlopen = True,
+    # Constructor needed for backend registration.
+    compiler_flags = ["-Wno-global-constructors"],
+    visibility = ["@EXECUTORCH_CLIENTS"],
+    deps = [
+        "//executorch/backends/aoti:common_shims",
+        "//executorch/extension/tensor:tensor",
+        "//executorch/runtime/core:core",
+        "//executorch/runtime/core/exec_aten:lib",
+        "//executorch/runtime/platform:platform",
+    ],
+    external_deps = [
+        ("cuda", None, "cuda-lazy"),
+    ],
+)
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
new file mode 100644
index 00000000000..99d936e32ca
--- /dev/null
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include // For posix_memalign
+#include
+#include
+#include
+
+// CUDA error checking macro
+#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR)   \
+  do {                                        \
+    const cudaError_t err = EXPR;             \
+    if (err == cudaSuccess) {                 \
+      break;                                  \
+    }                                         \
+    ET_LOG(                                   \
+        Error,                                \
+        "%s:%d CUDA error: %s",               \
+        __FILE__,                             \
+        __LINE__,                             \
+        cudaGetErrorString(err));             \
+    return Error::Internal;                   \
+  } while (0)
+
+// Kernel launch check macro
+#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
+  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+using executorch::aten::SizesType;
+using executorch::aten::StridesType;
+using executorch::backends::aoti::dtype_to_element_size;
+using executorch::backends::aoti::dtype_to_scalar_type;
+
+// Global storage for tensors and their metadata
+std::unordered_set<std::shared_ptr<Tensor>> tensors;
+
+extern "C" {
+
+AOTITorchError aoti_torch_empty_strided(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    Tensor** ret_new_tensor) {
+  // Check that device_index is always 0
+  if (device_index != 0) {
+    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+    return Error::InvalidArgument;
+  }
+
+  // This requires us to reserve CUDA memory and put it into an ETensor
+  void* ptr;
+  int64_t numel = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    numel *= sizes_ptr[i];
+  }
+
+  AOTITorchError dtype_error = validate_dtype(dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  size_t element_size = dtype_to_element_size(dtype);
+  if (element_size == 0) {
+    ET_LOG(Error, "Invalid element size for dtype: %d", dtype);
+    return Error::InvalidArgument;
+  }
+  int64_t nbytes = numel * element_size;
+
+  if (device_type == 1) { // cuda
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
+  } else if (device_type == 0) { // cpu
+    // Ensure 16-byte alignment for CPU memory to match CUDA requirements
+    int result = posix_memalign(&ptr, 16, nbytes);
+    if (result != 0) {
+      ET_LOG(Error, "Failed to allocate aligned CPU memory");
+      return Error::MemoryAllocationFailed;
+    }
+    if (ptr == nullptr) {
+      ET_LOG(Error, "posix_memalign returned a null pointer");
+      return Error::MemoryAllocationFailed;
+    }
+  } else {
+    ET_LOG(
+        Error,
+        "Need to implement empty_strided for non-CUDA non-CPU device type %d",
+        device_type);
+    return Error::NotImplemented;
+  }
+
+  // ETensor sizes
+  auto sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // ETensor strides
+  auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // ETensor creation with dynamic shape support for edge cases
+  auto tensor = executorch::extension::from_blob(
+      ptr, sizes, strides, dtype_to_scalar_type(dtype));
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+  *ret_new_tensor = tensor.get();
+
+  return Error::Ok;
+}
+
+// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
+void clear_all_tensors() {
+  tensors.clear();
+}
+
+} // extern "C"
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
new file mode 100644
index 00000000000..2fdfdd8a72c
--- /dev/null
+++ b/backends/cuda/runtime/shims/memory.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+using executorch::backends::aoti::AOTITorchError;
+using executorch::backends::aoti::Tensor;
+
+extern "C" {
+
+/**
+ * Creates an uninitialized tensor with specified dimensions, strides, and
+ * dtype on either CPU or CUDA device.
+ *
+ * @param ndim Number of dimensions in the tensor
+ * @param sizes_ptr Pointer to array of dimension sizes
+ * @param strides_ptr Pointer to array of strides for each dimension
+ * @param dtype Data type identifier (matches PyTorch scalar types)
+ * @param device_type Device type (0=CPU, 1=CUDA)
+ * @param device_index Device index (must be 0 for current implementation)
+ * @param ret_new_tensor Output parameter for the created tensor
+ * @return AOTITorchError error code (Error::Ok on success, or an error code on
+ * failure)
+ */
+AOTITorchError aoti_torch_empty_strided(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    Tensor** ret_new_tensor);
+
+// Function to clear all tensors from internal storage
+// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
+void clear_all_tensors();
+
+} // extern "C"
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
diff --git a/backends/cuda/runtime/shims/tests/TARGETS b/backends/cuda/runtime/shims/tests/TARGETS
new file mode 100644
index 00000000000..9ff3e83a8bd
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/TARGETS
@@ -0,0 +1,6 @@
+load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
new file mode 100644
index 00000000000..5737bdb00ab
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -0,0 +1,30 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
+load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
+
+def cuda_shim_cpp_unittest(name):
+    cpp_unittest(
+        name = "test_" + name,
+        srcs = [
+            "test_" + name + ".cpp",
+        ],
+        deps = [
+            "//executorch/backends/aoti:common_shims",
+            "//executorch/backends/cuda/runtime:runtime_shims",
+            "//executorch/extension/tensor:tensor",
+            "//executorch/runtime/core:core",
+            "//executorch/runtime/platform:platform",
+            "//executorch/runtime/core/exec_aten:lib",
+        ],
+        external_deps = [
+            ("cuda", None, "cuda-lazy"),
+        ],
+    )
+
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+    """
+    cuda_shim_cpp_unittest("aoti_torch_empty_strided")
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
new file mode 100644
index 00000000000..8e6998f457c
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
@@ -0,0 +1,588 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace executorch::backends::cuda;
+using namespace executorch::backends::aoti;
+using namespace executorch::runtime;
+using executorch::runtime::etensor::Tensor;
+
+// Test fixture for aoti_torch_empty_strided tests
+class AOTITorchEmptyStridedTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Initialize ExecuTorch Platform Abstraction Layer
+    et_pal_init();
+
+    // Check if CUDA is available
+    int device_count = 0;
+    cudaError_t err = cudaGetDeviceCount(&device_count);
+    if (err != cudaSuccess || device_count == 0) {
+      GTEST_SKIP() << "CUDA not available, skipping CUDA tests";
+    }
+
+    // Clean up any existing cached metadata before each test
+    cleanup_tensor_metadata();
+
+    // Clear any remaining tensors from previous tests
+    clear_all_tensors();
+  }
+
+  void TearDown() override {
+    // Clean up metadata
+    cleanup_tensor_metadata();
+
+    // Clear the global tensor storage using the provided function
+    clear_all_tensors();
+  }
+
+  // Helper to create test tensors
+  Tensor* create_tracked_tensor(
+      const std::vector<int64_t>& sizes,
+      const std::vector<int64_t>& strides = {},
+      int32_t dtype = static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      int32_t device_type = static_cast<int32_t>(SupportedDevices::CUDA),
+      int32_t device_index = 0) {
+    Tensor* tensor;
+
+    const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data();
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides_ptr,
+        dtype,
+        device_type,
+        device_index,
+        &tensor);
+
+    return (error == Error::Ok) ? tensor : nullptr;
+  }
+};
+
+// Test aoti_torch_empty_strided basic functionality
+TEST_F(AOTITorchEmptyStridedTest, BasicFunctionality) {
+  // Test 1D tensor
+  std::vector<int64_t> sizes_1d = {5};
+  Tensor* tensor_1d;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes_1d.size(),
+      sizes_1d.data(),
+      nullptr, // Let function compute strides
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor_1d);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_1d, nullptr);
+
+  // CRITICAL: Verify the tensor is actually float32
+  int32_t actual_dtype;
+  EXPECT_EQ(aoti_torch_get_dtype(tensor_1d, &actual_dtype), Error::Ok);
+  EXPECT_EQ(actual_dtype, static_cast<int32_t>(SupportedDTypes::FLOAT32))
+      << "Expected float32 dtype ("
+      << static_cast<int32_t>(SupportedDTypes::FLOAT32) << "), got "
+      << actual_dtype;
+
+  // Verify element size (float32 should be 4 bytes per element)
+  size_t element_size = tensor_1d->element_size();
+  EXPECT_EQ(element_size, 4)
+      << "Expected float32 element size to be 4 bytes, got " << element_size;
+
+  // Verify total number of elements and memory usage
+  int64_t expected_numel = 5; // 5 elements
+  EXPECT_EQ(tensor_1d->numel(), expected_numel)
+      << "Expected " << expected_numel << " elements, got "
+      << tensor_1d->numel();
+
+  // Verify total memory size (numel * element_size)
+  size_t expected_memory_size = expected_numel * 4; // 5 * 4 = 20 bytes
+  size_t actual_memory_size = tensor_1d->numel() * tensor_1d->element_size();
+  EXPECT_EQ(actual_memory_size, expected_memory_size)
+      << "Expected " << expected_memory_size << " bytes, got "
+      << actual_memory_size;
+
+  // Check tensor properties
+  EXPECT_EQ(tensor_1d->dim(), 1);
+  EXPECT_EQ(tensor_1d->size(0), 5);
+
+  // Test 2D tensor with explicit strides
+  std::vector<int64_t> sizes_2d = {3, 4};
+  std::vector<int64_t> strides_2d = {4, 1};
+  Tensor* tensor_2d;
+  error = aoti_torch_empty_strided(
+      sizes_2d.size(),
+      sizes_2d.data(),
+      strides_2d.data(),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor_2d);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_2d, nullptr);
+
+  // Verify 2D tensor is also float32
+  int32_t dtype_2d;
+  EXPECT_EQ(aoti_torch_get_dtype(tensor_2d, &dtype_2d), Error::Ok);
+  EXPECT_EQ(dtype_2d, static_cast<int32_t>(SupportedDTypes::FLOAT32))
+      << "Expected float32 dtype ("
+      << static_cast<int32_t>(SupportedDTypes::FLOAT32) << "), got "
+      << dtype_2d;
+
+  // Verify element size for 2D tensor
+  EXPECT_EQ(tensor_2d->element_size(), 4);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor_2d->dim(), 2);
+  EXPECT_EQ(tensor_2d->size(0), 3);
+  EXPECT_EQ(tensor_2d->size(1), 4);
+
+  // Verify memory size for 2D tensor
+  int64_t expected_numel_2d = 3 * 4; // 12 elements
+  size_t expected_memory_2d = expected_numel_2d * 4; // 12 * 4 = 48 bytes
+  EXPECT_EQ(tensor_2d->numel() * tensor_2d->element_size(), expected_memory_2d);
+}
+
+// Test aoti_torch_empty_strided with CPU device
+TEST_F(AOTITorchEmptyStridedTest, CPUDevice) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr, // Let function compute strides
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CPU),
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 3);
+}
+
+// Test aoti_torch_empty_strided with invalid dtype
+TEST_F(AOTITorchEmptyStridedTest, InvalidDtype) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      999, // invalid dtype
+      1, // CUDA device
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test aoti_torch_empty_strided with unsupported device
TEST_F(AOTITorchEmptyStridedTest, UnsupportedDevice) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      6, // float32
+      2, // unsupported device type
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::NotImplemented);
+}
+
+// Test aoti_torch_empty_strided with zero-sized tensor
+TEST_F(AOTITorchEmptyStridedTest, ZeroSized) {
+  std::vector<int64_t> sizes = {0, 5};
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      6, // float32
+      1, // CUDA device
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 0);
+  EXPECT_EQ(tensor->size(1), 5);
+}
+
+// Test aoti_torch_empty_strided scalar tensor (0D)
+TEST_F(AOTITorchEmptyStridedTest, Scalar) {
+  std::vector<int64_t> sizes = {};
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      6, // float32
+      1, // CUDA device
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor->dim(), 0);
+}
+
+// Test aoti_torch_empty_strided with large tensor
+TEST_F(AOTITorchEmptyStridedTest, LargeTensor) {
+  std::vector<int64_t> sizes = {100, 200, 50};
+  Tensor* tensor;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      6, // float32
+      1, // CUDA device
+      0, // device index
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor, nullptr);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor->dim(), 3);
+  EXPECT_EQ(tensor->size(0), 100);
+  EXPECT_EQ(tensor->size(1), 200);
+  EXPECT_EQ(tensor->size(2), 50);
+}
+
+// Test error handling with memory allocation failures
+TEST_F(AOTITorchEmptyStridedTest, MemoryAllocationStress) {
+  // Try to create a very large tensor that might cause allocation failure
+  // (This test may pass or fail depending on available memory)
+  std::vector<int64_t> huge_sizes = {10000, 10000, 100}; // ~38GB for float32
+  Tensor* tensor;
+
+  AOTITorchError error = aoti_torch_empty_strided(
+      huge_sizes.size(),
+      huge_sizes.data(),
+      nullptr,
+      6, // float32
+      1, // CUDA device
+      0, // device index
+      &tensor);
+
+  // Either succeed or fail with memory allocation error
+  if (error == Error::Ok) {
+    EXPECT_NE(tensor, nullptr);
+  } else {
+    EXPECT_EQ(error, Error::MemoryAllocationFailed);
+  }
+}
+
+// Test aoti_torch_empty_strided with bfloat16 dtype
+TEST_F(AOTITorchEmptyStridedTest, BFloat16Tensor) {
+  // Test creating bfloat16 tensor on CUDA
+  std::vector<int64_t> sizes = {2, 3, 4};
+  Tensor* tensor_bf16;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr, // Let function compute strides
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor_bf16);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_bf16, nullptr);
+
+  // CRITICAL: Verify the tensor is actually bfloat16
+  int32_t actual_dtype;
+  EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16, &actual_dtype), Error::Ok);
+  EXPECT_EQ(actual_dtype, static_cast<int32_t>(SupportedDTypes::BFLOAT16))
+      << "Expected bfloat16 dtype ("
+      << static_cast<int32_t>(SupportedDTypes::BFLOAT16) << "), got "
+      << actual_dtype;
+
+  // Verify element size (bfloat16 should be 2 bytes per element)
+  size_t element_size = tensor_bf16->element_size();
+  EXPECT_EQ(element_size, 2)
+      << "Expected bfloat16 element size to be 2 bytes, got " << element_size;
+
+  // Verify total number of elements and memory usage
+  int64_t expected_numel = 2 * 3 * 4; // 24 elements
+  EXPECT_EQ(tensor_bf16->numel(), expected_numel)
+      << "Expected " << expected_numel << " elements, got "
+      << tensor_bf16->numel();
+
+  // Verify total memory size (numel * element_size)
+  size_t expected_memory_size = expected_numel * 2; // 24 * 2 = 48 bytes
+  size_t actual_memory_size =
+      tensor_bf16->numel() * tensor_bf16->element_size();
+  EXPECT_EQ(actual_memory_size, expected_memory_size)
+      << "Expected " << expected_memory_size << " bytes, got "
+      << actual_memory_size;
+
+  // Check tensor properties
+  EXPECT_EQ(tensor_bf16->dim(), 3);
+  EXPECT_EQ(tensor_bf16->size(0), 2);
+  EXPECT_EQ(tensor_bf16->size(1), 3);
+  EXPECT_EQ(tensor_bf16->size(2), 4);
+
+  // Verify we can get tensor metadata
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_sizes(tensor_bf16, &sizes_ptr), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor_bf16, &strides_ptr), Error::Ok);
+
+  // Check sizes match
+  EXPECT_EQ(sizes_ptr[0], 2);
+  EXPECT_EQ(sizes_ptr[1], 3);
+  EXPECT_EQ(sizes_ptr[2], 4);
+
+  // Check that strides are computed correctly (row-major order)
+  EXPECT_EQ(strides_ptr[0], 12); // 3 * 4
+  EXPECT_EQ(strides_ptr[1], 4); // 4
+  EXPECT_EQ(strides_ptr[2], 1); // 1
+
+  // Test bfloat16 tensor with custom strides
+  std::vector<int64_t> sizes_2d = {3, 2};
+  std::vector<int64_t> strides_2d = {2, 1}; // Row-major strides
+  Tensor* tensor_bf16_custom;
+  error = aoti_torch_empty_strided(
+      sizes_2d.size(),
+      sizes_2d.data(),
+      strides_2d.data(),
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor_bf16_custom);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_bf16_custom, nullptr);
+
+  // Verify custom stride tensor is also bfloat16
+  int32_t custom_dtype;
+  EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_custom, &custom_dtype), Error::Ok);
+  EXPECT_EQ(custom_dtype, static_cast<int32_t>(SupportedDTypes::BFLOAT16))
+      << "Expected bfloat16 dtype ("
+      << static_cast<int32_t>(SupportedDTypes::BFLOAT16) << "), got "
+      << custom_dtype;
+
+  // Verify element size for custom stride tensor
+  EXPECT_EQ(tensor_bf16_custom->element_size(), 2);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor_bf16_custom->dim(), 2);
+  EXPECT_EQ(tensor_bf16_custom->size(0), 3);
+  EXPECT_EQ(tensor_bf16_custom->size(1), 2);
+
+  // Verify memory size for custom stride tensor
+  int64_t custom_expected_numel = 3 * 2; // 6 elements
+  size_t custom_expected_memory = custom_expected_numel * 2; // 6 * 2 = 12 bytes
+  EXPECT_EQ(
+      tensor_bf16_custom->numel() * tensor_bf16_custom->element_size(),
+      custom_expected_memory);
+
+  // Check custom strides
+  int64_t* custom_strides_ptr;
+  EXPECT_EQ(
+      aoti_torch_get_strides(tensor_bf16_custom, &custom_strides_ptr),
+      Error::Ok);
+  EXPECT_EQ(custom_strides_ptr[0], 2);
+  EXPECT_EQ(custom_strides_ptr[1], 1);
+
+  // Test bfloat16 scalar tensor (0D)
+  std::vector<int64_t> scalar_sizes = {};
+  Tensor* tensor_bf16_scalar;
+  error = aoti_torch_empty_strided(
+      scalar_sizes.size(),
+      scalar_sizes.data(),
+      nullptr,
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16),
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0, // device index
+      &tensor_bf16_scalar);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_bf16_scalar, nullptr);
+  EXPECT_EQ(tensor_bf16_scalar->dim(), 0);
+
+  // Verify scalar tensor is also bfloat16
+  int32_t scalar_dtype;
+  EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_scalar, &scalar_dtype), Error::Ok);
+  EXPECT_EQ(scalar_dtype, static_cast<int32_t>(SupportedDTypes::BFLOAT16))
+      << "Expected bfloat16 dtype ("
+      << static_cast<int32_t>(SupportedDTypes::BFLOAT16) << "), got "
+      << scalar_dtype;
+
+  // Verify scalar tensor properties
+  EXPECT_EQ(tensor_bf16_scalar->element_size(), 2);
+  EXPECT_EQ(tensor_bf16_scalar->numel(), 1); // Scalar tensor has 1 element
+  EXPECT_EQ(
+      tensor_bf16_scalar->numel() * tensor_bf16_scalar->element_size(),
+      2); // 1 * 2 = 2 bytes
+}
+
+// Test custom strides functionality
+TEST_F(AOTITorchEmptyStridedTest, CustomStrides) {
+  // Create tensor with valid custom strides (contiguous layout)
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = {3, 1}; // Standard row-major strides
+
+  Tensor* tensor = create_tracked_tensor(sizes, strides);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify the tensor was created correctly
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 3);
+
+  // Check strides through AOTI interface
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+  EXPECT_EQ(strides_ptr[0], 3);
+  EXPECT_EQ(strides_ptr[1], 1);
+
+  // Test another valid stride pattern - transpose-like
+  std::vector<int64_t> sizes_2 = {3, 2};
+  std::vector<int64_t> strides_2 = {1, 3}; // Column-major strides
+
+  Tensor* tensor_2 = create_tracked_tensor(sizes_2, strides_2);
+  EXPECT_NE(tensor_2, nullptr);
+
+  // Verify the tensor properties
+  EXPECT_EQ(tensor_2->dim(), 2);
+  EXPECT_EQ(tensor_2->size(0), 3);
+  EXPECT_EQ(tensor_2->size(1), 2);
+
+  // Check strides
+  int64_t* strides_ptr_2;
+  EXPECT_EQ(aoti_torch_get_strides(tensor_2, &strides_ptr_2), Error::Ok);
+  EXPECT_EQ(strides_ptr_2[0], 1);
+  EXPECT_EQ(strides_ptr_2[1], 3);
+}
+
+// Test edge case: zero-element tensor with non-zero dimensions
+TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) {
+  std::vector<int64_t> sizes = {2, 0, 3}; // Total elements = 0
+  Tensor* tensor = create_tracked_tensor(sizes);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify the tensor properties
+  EXPECT_EQ(tensor->dim(), 3);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 0);
+  EXPECT_EQ(tensor->size(2), 3);
+
+  // Should be able to get metadata
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+
+  EXPECT_EQ(sizes_ptr[0], 2);
+  EXPECT_EQ(sizes_ptr[1], 0);
+  EXPECT_EQ(sizes_ptr[2], 3);
+}
+
+// Test different data types (only float32 and bfloat16 are currently supported)
+TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
+  std::vector<int64_t> sizes = {2, 3};
+
+  // Test float32 (dtype 6) - a supported type
+  Tensor* tensor_float32;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      6, // float32
+      1, // CUDA device
+      0, // device index
+      &tensor_float32);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_float32, nullptr);
+
+  // Test unsupported data types should return error
+  Tensor* tensor_int32;
+  error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      3, // int32 - unsupported
+      1, // CUDA device
+      0, // device index
+      &tensor_int32);
+
+  EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
+
+  // Test another unsupported data type
+  Tensor* tensor_float64;
+  error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      7, // float64 - unsupported
+      1, // CUDA device
+      0, // device index
+      &tensor_float64);
+
+  EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
+}
+
+// Test multi-dimensional tensors with various shapes
+TEST_F(AOTITorchEmptyStridedTest, MultiDimensionalTensors) {
+  // Test 3D tensor
+  std::vector<int64_t> sizes_3d = {2, 3, 4};
+  Tensor* tensor_3d = create_tracked_tensor(sizes_3d);
+  EXPECT_NE(tensor_3d, nullptr);
+  EXPECT_EQ(tensor_3d->dim(), 3);
+  EXPECT_EQ(tensor_3d->size(0), 2);
+  EXPECT_EQ(tensor_3d->size(1), 3);
+  EXPECT_EQ(tensor_3d->size(2), 4);
+
+  // Test 4D tensor
+  std::vector<int64_t> sizes_4d = {2, 3, 4, 5};
+  Tensor* tensor_4d = create_tracked_tensor(sizes_4d);
+  EXPECT_NE(tensor_4d, nullptr);
+  EXPECT_EQ(tensor_4d->dim(), 4);
+  EXPECT_EQ(tensor_4d->size(0), 2);
+  EXPECT_EQ(tensor_4d->size(1), 3);
+  EXPECT_EQ(tensor_4d->size(2), 4);
+  EXPECT_EQ(tensor_4d->size(3), 5);
+
+  // Test 5D tensor
+  std::vector<int64_t> sizes_5d = {1, 2, 3, 4, 5};
+  Tensor* tensor_5d = create_tracked_tensor(sizes_5d);
+  EXPECT_NE(tensor_5d, nullptr);
+  EXPECT_EQ(tensor_5d->dim(), 5);
+  EXPECT_EQ(tensor_5d->size(0), 1);
+  EXPECT_EQ(tensor_5d->size(1), 2);
+  EXPECT_EQ(tensor_5d->size(2), 3);
+  EXPECT_EQ(tensor_5d->size(3), 4);
+  EXPECT_EQ(tensor_5d->size(4), 5);
+}
diff --git a/backends/cuda/runtime/shims/utils.h b/backends/cuda/runtime/shims/utils.h
new file mode 100644
index 00000000000..23943391b50
--- /dev/null
+++ b/backends/cuda/runtime/shims/utils.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+// Enum for supported data types in et-cuda backend
+enum class SupportedDTypes : int32_t {
+  FLOAT32 = 6, // PyTorch's float32 dtype code
+  BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
+};
+
+// Enum for supported device types in et-cuda backend
+enum class SupportedDevices : int32_t {
+  CPU = 0, // CPU device
+  CUDA = 1, // CUDA device
+};
+
+// Utility function to convert sizes pointer to vector
+inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr) {
+  std::vector<executorch::aten::SizesType> sizes(ndim);
+  for (int i = 0; i < ndim; i++) {
+    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
+  }
+  return sizes;
+}
+
+// Utility function to convert strides pointer to vector or calculate from sizes
+inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  std::vector<executorch::aten::StridesType> strides(ndim);
+
+  if (strides_ptr != nullptr) {
+    // Use provided strides. It is ok if the provided strides are not
+    // contiguous, since they are only used internally in the CUDA delegate.
+    for (int64_t i = 0; i < ndim; i++) {
+      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
+    }
+  } else {
+    // Calculate strides from sizes using ExecuTorch's algorithm
+    if (ndim > 0) {
+      // Last dimension has stride 1
+      strides[ndim - 1] = static_cast<executorch::aten::StridesType>(1);
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        if (sizes_ptr[i + 1] == 0) {
+          strides[i] = strides[i + 1]; // Copy stride when size is 0
+        } else {
+          strides[i] = static_cast<executorch::aten::StridesType>(
+              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
+        }
+      }
+    }
+  }
+  return strides;
+}
+
+extern "C" {
+using executorch::runtime::Error;
+// Common AOTI type aliases
+using AOTITorchError = Error;
+
+// Helper function to check if a dtype is supported in ET CUDA backend
+inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
+  switch (dtype) {
+    case static_cast<int32_t>(SupportedDTypes::FLOAT32):
+    case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Dtype validation utility function
+inline AOTITorchError validate_dtype(int32_t dtype) {
+  if (is_dtype_supported_in_et_cuda(dtype)) {
+    return Error::Ok;
+  }
+
+  ET_LOG(
+      Error,
+      "Unsupported dtype: %d. Supported dtypes: %d (float32), %d (bfloat16)",
+      dtype,
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+  return Error::InvalidArgument;
+}
+} // extern "C"
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
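
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller might drive the new shim end to end, mirroring the unit tests above. It uses only names introduced in this diff (aoti_torch_empty_strided, clear_all_tensors, SupportedDTypes, SupportedDevices); the executorch include paths are assumed from the file locations added here.

#include <cstdint>
#include <vector>

// Include paths assumed from the locations of the new headers in this patch.
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/utils.h>

int main() {
  using executorch::backends::cuda::aoti_torch_empty_strided;
  using executorch::backends::cuda::AOTITorchError;
  using executorch::backends::cuda::clear_all_tensors;
  using executorch::backends::cuda::SupportedDevices;
  using executorch::backends::cuda::SupportedDTypes;
  using executorch::backends::cuda::Tensor;
  using executorch::runtime::Error;

  // Ask the shim for an uninitialized 2x3 float32 tensor on CUDA and let it
  // compute contiguous strides (strides_ptr == nullptr).
  std::vector<int64_t> sizes = {2, 3};
  Tensor* tensor = nullptr;
  AOTITorchError err = aoti_torch_empty_strided(
      static_cast<int64_t>(sizes.size()),
      sizes.data(),
      nullptr,
      static_cast<int32_t>(SupportedDTypes::FLOAT32),
      static_cast<int32_t>(SupportedDevices::CUDA),
      0, // device_index must be 0 in the current implementation
      &tensor);

  // Ownership stays inside the shim's global tensor set; release everything
  // in one call when done.
  clear_all_tensors();
  return err == Error::Ok ? 0 : 1;
}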