diff --git a/kernels/portable/cpu/op_abs.cpp b/kernels/portable/cpu/op_abs.cpp index 0dd925a0e25..9c2c219832d 100644 --- a/kernels/portable/cpu/op_abs.cpp +++ b/kernels/portable/cpu/op_abs.cpp @@ -28,6 +28,8 @@ Tensor& abs_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { "Failed to resize output tensor."); ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] { apply_unary_map_fn( diff --git a/kernels/test/TestUtil.h b/kernels/test/TestUtil.h index ed72dbc4128..8d782d3c2a9 100644 --- a/kernels/test/TestUtil.h +++ b/kernels/test/TestUtil.h @@ -30,6 +30,22 @@ #define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \ EXPECT_ANY_THROW(_statement) +#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \ + tf, op, input_contiguous, expected_contiguous, channels_last_support) \ + Tensor input_channels_last = tf.channels_last_like(input_contiguous); \ + Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \ + \ + Tensor output_contiguous = tf.zeros_like(expected_contiguous); \ + Tensor output_channels_last = tf.channels_last_like(output_contiguous); \ + \ + Tensor ret = op(input_channels_last, output_channels_last); \ + if (channels_last_support) { \ + EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \ + } else { \ + EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \ + } \ + EXPECT_TENSOR_EQ(output_channels_last, ret); + #else #define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \ @@ -52,6 +68,26 @@ } \ } while (false) +#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \ + tf, op, input_contiguous, expected_contiguous, channels_last_support) \ + Tensor input_channels_last = tf.channels_last_like(input_contiguous); \ + Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \ + \ + Tensor output_contiguous = tf.zeros_like(expected_contiguous); \ + Tensor output_channels_last = tf.channels_last_like(output_contiguous); \ + \ + Tensor ret = op(input_channels_last, output_channels_last); \ + if (channels_last_support) { \ + EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \ + } else { \ + EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \ + } \ + EXPECT_TENSOR_EQ(output_channels_last, ret); \ + ET_EXPECT_KERNEL_FAILURE( \ + context_, op(input_channels_last, output_contiguous)); \ + ET_EXPECT_KERNEL_FAILURE( \ + context_, op(input_contiguous, output_channels_last)); + #endif // USE_ATEN_LIB /* diff --git a/kernels/test/op_abs_test.cpp b/kernels/test/op_abs_test.cpp index b54cd971567..f596d586d90 100644 --- a/kernels/test/op_abs_test.cpp +++ b/kernels/test/op_abs_test.cpp @@ -38,3 +38,28 @@ TEST_F(OpAbsTest, SanityCheck) { EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, expected); } + +TEST_F(OpAbsTest, MemoryFormatCheck) { + TensorFactory tf; + + std::vector sizes = {2, 3, 1, 5}; + + Tensor input_contiguous = + tf.make(sizes, {0.8737, 0.5359, 0.3743, -0.3040, -0.7800, -0.2306, + -0.7684, -0.5364, 0.3478, -0.3289, 0.0829, 0.2939, + -0.8211, 0.8572, -0.0802, 0.9252, -0.2093, 0.9013, + -0.4197, 0.3987, -0.5291, -0.5567, 0.2691, 0.7819, + -0.8009, -0.4286, -0.9299, 0.2143, 0.2565, -0.5701}); + Tensor expected_contiguous = tf.make( + sizes, {0.8737, 0.5359, 0.3743, 0.3040, 0.7800, 0.2306, 0.7684, 0.5364, + 0.3478, 0.3289, 0.0829, 0.2939, 0.8211, 0.8572, 0.0802, 0.9252, + 0.2093, 0.9013, 0.4197, 0.3987, 0.5291, 0.5567, 0.2691, 0.7819, + 0.8009, 0.4286, 0.9299, 0.2143, 0.2565, 0.5701}); + + ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( + tf, + op_abs_out, + input_contiguous, + expected_contiguous, + /*channels_last_support=*/true); +} diff --git a/runtime/core/exec_aten/testing_util/tensor_factory.h b/runtime/core/exec_aten/testing_util/tensor_factory.h index 8f39cc9911d..3045af55819 100644 --- a/runtime/core/exec_aten/testing_util/tensor_factory.h +++ b/runtime/core/exec_aten/testing_util/tensor_factory.h @@ -3,8 +3,10 @@ #pragma once #include +#include #include +#include #include #include #include @@ -54,7 +56,7 @@ inline size_t sizes_to_numel(const std::vector& sizes) { inline bool check_strides( const std::vector sizes, - const std::vector strides) { + const std::vector strides) { if (sizes.size() != strides.size()) { // The length of stride vector shall equal to size vector. return false; @@ -147,14 +149,14 @@ inline bool check_dim_order( return true; } -inline std::vector strides_from_dim_order( +inline std::vector strides_from_dim_order( const std::vector& sizes, const std::vector& dim_order) { bool legal = check_dim_order(sizes, dim_order); ET_CHECK_MSG(legal, "The input dim_order variable is illegal."); size_t ndim = sizes.size(); - std::vector strides(ndim); + std::vector strides(ndim); strides[dim_order[ndim - 1]] = 1; for (int i = ndim - 2; i >= 0; --i) { uint8_t cur_dim = dim_order[i]; @@ -258,7 +260,7 @@ class TensorFactory { at::Tensor make( const std::vector& sizes, const std::vector& data, - const std::vector strides = {}, + const std::vector strides = {}, ET_UNUSED TensorShapeDynamism dynamism = TensorShapeDynamism::DYNAMIC_UNBOUND) { auto expected_numel = internal::sizes_to_numel(sizes); @@ -344,6 +346,72 @@ class TensorFactory { sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism); } + /** + * Given data in contiguous memory format, returns a new Tensor with the + * specified shape and the same data but in channels last memory format. + * + * @param[in] sizes The sizes of the dimensions of the Tensor. + * @param[in] data The data in contiguous memory format that the Tensor should + * be initialized with. The size of this vector must be equal to the product + * of the elements of `sizes`. + * + * @return A new Tensor with the specified shape and data in channls last + * memory format. + */ + at::Tensor channels_last_like( + const at::Tensor& input, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + ET_CHECK_MSG( + input.sizes().size() == 4, "Only 4D tensors can be channels last"); + + const std::vector sizes( + input.sizes().begin(), input.sizes().end()); + + std::vector contiguous_dim_order(sizes.size()); + for (uint8_t i = 0; i < sizes.size(); i++) { + contiguous_dim_order[i] = i; + } + std::vector contiguous_strides = + internal::strides_from_dim_order(sizes, contiguous_dim_order); + + for (int32_t i = 0; i < input.dim(); i++) { + ET_CHECK_MSG( + input.strides()[i] == contiguous_strides[i], + "Input tensor is not contiguous"); + } + + int32_t N = sizes[0]; + int32_t C = sizes[1]; + int32_t H = sizes[2]; + int32_t W = sizes[3]; + + std::vector contiguous_data( + input.data_ptr(), input.data_ptr() + input.numel()); + std::vector channels_last_data( + N * C * H * W); // Create a new blob with the same total size to contain + // channels_last data + for (int32_t n = 0; n < N; ++n) { + for (int32_t c = 0; c < C; ++c) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + // Calculate the index in the original blob + int32_t old_index = ((n * C + c) * H + h) * W + w; + // Calculate the index in the new blob + int32_t new_index = ((n * H + h) * W + w) * C + c; + // Copy the data + channels_last_data[new_index] = contiguous_data[old_index]; + } + } + } + } + + return make_with_dimorder( + sizes, + channels_last_data, + internal::channels_last_dim_order(sizes.size()), + dynamism); + } + /** * Returns a new Tensor with the specified shape, containing contiguous * data will all elements set to `value`. @@ -459,14 +527,13 @@ class TensorFactory { */ at::Tensor empty_strided( const std::vector& sizes, - const std::vector& strides, + const std::vector& strides, ET_UNUSED TensorShapeDynamism dynamism = TensorShapeDynamism::DYNAMIC_UNBOUND) { auto sizes64 = vec_32_to_64(sizes); - auto strides64 = vec_32_to_64(strides); return at::empty_strided( sizes64, - strides64, + strides, DTYPE, /*layout_opt=*/at::Layout::Strided, /*device_opt=*/at::Device(at::DeviceType::CPU), @@ -666,7 +733,7 @@ class TensorFactory { torch::executor::Tensor make( const std::vector& sizes, const std::vector& data, - const std::vector strides = {}, + const std::vector strides = {}, TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { std::vector default_strides; // Generate strides from the tensor dimensions, assuming contiguous data if @@ -746,7 +813,7 @@ class TensorFactory { /** * Returns a new Tensor with the specified shape and data in channels last - * memory layout. + * memory format. * * @param[in] sizes The sizes of the dimensions of the Tensor. * @param[in] data The data that the Tensor should be initialized with. The @@ -764,6 +831,60 @@ class TensorFactory { sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism); } + /** + * Given data in contiguous memory format, returns a new Tensor with the + * specified shape and the same data but in channels last memory format. + * + * @param[in] sizes The sizes of the dimensions of the Tensor. + * @param[in] data The data in contiguous memory format that the Tensor should + * be initialized with. The size of this vector must be equal to the product + * of the elements of `sizes`. + * + * @return A new Tensor with the specified shape and data in channls last + * memory format. + */ + torch::executor::Tensor channels_last_like( + const torch::executor::Tensor& input, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + const std::vector sizes( + input.sizes().begin(), input.sizes().end()); + + ET_CHECK_MSG(sizes.size() == 4, "Only 4D tensors can be channels last"); + ET_CHECK_MSG( + is_contiguous_dim_order(input.dim_order().data(), input.dim()) == true, + "Input tensor is not contiguous"); + int32_t N = sizes[0]; + int32_t C = sizes[1]; + int32_t H = sizes[2]; + int32_t W = sizes[3]; + + std::vector contiguous_data( + input.data_ptr(), input.data_ptr() + input.numel()); + std::vector channels_last_data( + N * C * H * W); // Create a new blob with the same total size to contain + // channels_last data + for (int32_t n = 0; n < N; ++n) { + for (int32_t c = 0; c < C; ++c) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + // Calculate the index in the original blob + int32_t old_index = ((n * C + c) * H + h) * W + w; + // Calculate the index in the new blob + int32_t new_index = ((n * H + h) * W + w) * C + c; + // Copy the data + channels_last_data[new_index] = contiguous_data[old_index]; + } + } + } + } + + return make_with_dimorder( + sizes, + channels_last_data, + internal::channels_last_dim_order(sizes.size()), + dynamism); + } + /** * Returns a new Tensor with the specified shape, containing contiguous data * will all elements set to `value`. @@ -799,7 +920,20 @@ class TensorFactory { /** * Returns a new Tensor with the specified shape, containing contiguous data - * with all `0` elements. + * in channels last memory format with all `0` elements. + * + * @param[in] sizes The sizes of the dimensions of the Tensor. + * @return A new Tensor with the specified shape. + */ + torch::executor::Tensor zeros_channels_last( + const std::vector& sizes, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + return full_channels_last(sizes, 0, dynamism); + } + + /** + * Returns a new Tensor with the specified shape, containing contiguous data + * in contiguous memory format with all `0` elements. * * @param[in] sizes The sizes of the dimensions of the Tensor. * @return A new Tensor with the specified shape. @@ -878,7 +1012,7 @@ class TensorFactory { std::vector sizes_; std::vector data_; std::vector dim_order_; - std::vector strides_; + std::vector strides_; torch::executor::TensorImpl impl_; }; diff --git a/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp b/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp index a2bc36f4814..8681e9553a6 100644 --- a/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp +++ b/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp @@ -449,7 +449,7 @@ TEST_F(TensorFactoryTest, MakeStridedDataIsCopied) { // Create two tensors using the same input data and strided vector. std::vector data = {1, 2, 3, 4}; - std::vector strides = {1, 2}; + std::vector strides = {1, 2}; Tensor t1 = tf.make(/*sizes=*/{2, 2}, data, strides); Tensor t2 = tf.make(/*sizes=*/{2, 2}, data, strides); diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index b18cd349a62..4dcb0ef9f69 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -235,8 +235,9 @@ */ #define ET_CHECK_CONTIGUOUS(a__) \ ({ \ - const ::exec_aten::ArrayRef strides = a__.strides(); \ - const ::exec_aten::ArrayRef sizes = a__.sizes(); \ + const ::exec_aten::ArrayRef strides = \ + a__.strides(); \ + const ::exec_aten::ArrayRef sizes = a__.sizes(); \ ET_CHECK_MSG( \ strides[strides.size() - 1] == 1, \ "The stride of the last dimension shall be 1 for contiguous tensor, " \ @@ -267,8 +268,10 @@ "Two tensors shall have same number of strides, but not %zu and %zu.", \ a__.dim(), \ b__.dim()); \ - const ::exec_aten::ArrayRef a_strides = a__.strides(); \ - const ::exec_aten::ArrayRef b_strides = b__.strides(); \ + const ::exec_aten::ArrayRef a_strides = \ + a__.strides(); \ + const ::exec_aten::ArrayRef b_strides = \ + b__.strides(); \ for (size_t i = 0; i < a__.dim(); i++) { \ ET_CHECK_MSG( \ a_strides[i] == b_strides[i], \ @@ -276,8 +279,8 @@ "but now is %d and %d.", \ i, \ i, \ - a_strides[i], \ - b_strides[i]); \ + (int32_t)a_strides[i], \ + (int32_t)b_strides[i]); \ } \ }) @@ -295,9 +298,12 @@ a__.dim(), \ b__.dim(), \ c__.dim()); \ - const ::exec_aten::ArrayRef a_strides = a__.strides(); \ - const ::exec_aten::ArrayRef b_strides = b__.strides(); \ - const ::exec_aten::ArrayRef c_strides = c__.strides(); \ + const ::exec_aten::ArrayRef a_strides = \ + a__.strides(); \ + const ::exec_aten::ArrayRef b_strides = \ + b__.strides(); \ + const ::exec_aten::ArrayRef c_strides = \ + c__.strides(); \ for (size_t i = 0; i < a__.dim(); i++) { \ ET_CHECK_MSG( \ a_strides[i] == b_strides[i] && b_strides[i] == c_strides[i], \ @@ -306,9 +312,9 @@ i, \ i, \ i, \ - a_strides[i], \ - b_strides[i], \ - c_strides[i]); \ + (int32_t)a_strides[i], \ + (int32_t)b_strides[i], \ + (int32_t)c_strides[i]); \ } \ }) @@ -848,11 +854,11 @@ inline bool tensor_is_scalar(exec_aten::Tensor t) { /** * The expected output size may not be the existing size of any inputs and - * outputs if the operator supports both broadcast and dynamic shape. Therefore - * such operators needs extra space to store the calculated expected output - * size. such dynamic allocation is troublesome in executorch so we can just - * hard code a static value of a relatively small value because users don't - * create high dimensional tensors. + * outputs if the operator supports both broadcast and dynamic shape. + * Therefore such operators needs extra space to store the calculated expected + * output size. such dynamic allocation is troublesome in executorch so we can + * just hard code a static value of a relatively small value because users + * don't create high dimensional tensors. */ constexpr size_t kTensorDimensionLimit = 16; @@ -893,8 +899,8 @@ inline size_t getTrailingDims(const exec_aten::Tensor& tensor, int64_t dim) { * @param[in] tensor The tensor that will be indexed * @param[in] coordinate A n-dimensional array representing the coordinate to * index. It is assumed that the array has kTensorDimensionLimit elements. - * @param[out] index The linear index to element at the specified coordinate in - * the tensor. + * @param[out] index The linear index to element at the specified coordinate + * in the tensor. */ inline size_t coordinateToIndex( const exec_aten::Tensor& tensor, @@ -935,10 +941,10 @@ inline void indexToCoordinate( * * @param[in] tensor The source of the value to extract. * @param[out] out_val The extracted value, on success. - * @returns `true` if a value was extracted, and sets `*out_val` to that value. - * `false` if a value could not be extracted: either it was not an integer - * Scalar Tensor, or the value of that Scalar Tensor could not be represented - * by INT_T. + * @returns `true` if a value was extracted, and sets `*out_val` to that + * value. `false` if a value could not be extracted: either it was not an + * integer Scalar Tensor, or the value of that Scalar Tensor could not be + * represented by INT_T. */ template < typename INT_T, @@ -973,10 +979,10 @@ bool extract_scalar_tensor(exec_aten::Tensor tensor, INT_T* out_val) { * * @param[in] tensor The source of the value to extract. * @param[out] out_val The extracted value, on success. - * @returns `true` if a value was extracted, and sets `*out_val` to that value. - * `false` if a value could not be extracted: either it was not a floating - * point Scalar Tensor, or the value of that Scalar Tensor could not be - * represented by FLOAT_T. + * @returns `true` if a value was extracted, and sets `*out_val` to that + * value. `false` if a value could not be extracted: either it was not a + * floating point Scalar Tensor, or the value of that Scalar Tensor could not + * be represented by FLOAT_T. */ template < typename FLOAT_T, @@ -1076,9 +1082,9 @@ ET_NODISCARD Error resize_tensor_impl( * expand the tensor if new size exceeds the current capacity. Currently * fails an ET_CHECK if the tensor cannot be resized. * - * WARNING: Placeholder API until discussion around runtime context is settled, - * will likely move to be a class method on a TensorResizer object passed in - * through runtimeContext. + * WARNING: Placeholder API until discussion around runtime context is + * settled, will likely move to be a class method on a TensorResizer object + * passed in through runtimeContext. */ ET_NODISCARD inline Error resize_tensor( exec_aten::Tensor t, @@ -1091,9 +1097,9 @@ ET_NODISCARD inline Error resize_tensor( * expand the tensor if new size exceeds the current capacity. Currently * fails an ET_CHECK if the tensor cannot be resized. * - * WARNING: Placeholder API until discussion around runtime context is settled, - * will likely move to be a class method on a TensorResizer object passed in - * through runtimeContext. + * WARNING: Placeholder API until discussion around runtime context is + * settled, will likely move to be a class method on a TensorResizer object + * passed in through runtimeContext. */ template < typename T, @@ -1124,8 +1130,8 @@ ET_DEPRECATED inline void resize( /** * Get dim_order of a Tensor and write it to out_dim_order. * @param tensor The tensor where we want to get dim order from. - * @param out_dim_order Pointing to an array of DimOrderType where we write dim - * order into it. + * @param out_dim_order Pointing to an array of DimOrderType where we write + * dim order into it. * @param out_dim_order_size Size of the DimOrderType array. */ ET_NODISCARD Error get_dim_order( @@ -1134,18 +1140,47 @@ ET_NODISCARD Error get_dim_order( size_t out_dim_order_size); /** - * Checks whether a tensor has a valid dim order. If the dim order could not be - * determined, then this function returns false by default. + * Checks whether a tensor has a valid dim order. If the dim order could not + * be determined, then this function returns false by default. */ bool tensor_has_valid_dim_order(exec_aten::Tensor t); /** - * Checks whether a tensor has either the default of channels last dim order. If - * the dim order could not be determined, then this function returns false by - * default. + * Checks whether a tensor has either the default of channels last dim order. + * If the dim order could not be determined, then this function returns false + * by default. */ bool tensor_is_default_or_channels_last_dim_order(exec_aten::Tensor t); +/** + * Asserts that two tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. Also this macro does not support ATen mode since we do not + * support dim order in ATen mode. + * + * TODO(T183094318): Add dim order and related function support for ATen mode. + */ + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b); + +/** + * Asserts that three tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. Also this macro does not support ATen mode since we do not + * support dim order in ATen mode. + * + * TODO(T183094318): Add dim order and related function support for ATen mode. + */ + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b, + const exec_aten::Tensor& c); + /** * Given an n-dimensional coordinate array and an array of tensor strides, * calculates the linear index that can be used to retrieve the value at the @@ -1205,6 +1240,7 @@ using ::executorch::runtime::tensor_is_real_type; using ::executorch::runtime::tensor_is_realh_type; using ::executorch::runtime::tensor_is_realhb_type; using ::executorch::runtime::tensor_is_scalar; +using ::executorch::runtime::tensors_have_same_dim_order; using ::executorch::runtime::tensors_have_same_dtype; using ::executorch::runtime::tensors_have_same_rank; using ::executorch::runtime::tensors_have_same_shape; diff --git a/runtime/core/exec_aten/util/tensor_util_aten.cpp b/runtime/core/exec_aten/util/tensor_util_aten.cpp index c5ff3b52234..91b75c06483 100644 --- a/runtime/core/exec_aten/util/tensor_util_aten.cpp +++ b/runtime/core/exec_aten/util/tensor_util_aten.cpp @@ -77,6 +77,64 @@ inline bool tensor_is_default_or_channels_last_dim_order(at::Tensor t) { return ret_val; } +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b) { + exec_aten::DimOrderType a_dim_order[kTensorDimensionLimit]; + exec_aten::DimOrderType b_dim_order[kTensorDimensionLimit]; + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(a, a_dim_order, a.dim()) == Error::Ok, + "Failed to retrieve dim order from first input tensor!"); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(b, b_dim_order, b.dim()) == Error::Ok, + "Failed to retrieve dim order from second input tensor!"); + + bool all_contiguous = is_contiguous_dim_order(a_dim_order, a.dim()) && + is_contiguous_dim_order(b_dim_order, b.dim()); + + bool all_channels_last = is_channels_last_dim_order(a_dim_order, a.dim()) && + is_channels_last_dim_order(b_dim_order, b.dim()); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Two input tensors have different dim orders"); + + return true; +} + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b, + const exec_aten::Tensor& c) { + exec_aten::DimOrderType a_dim_order[kTensorDimensionLimit]; + exec_aten::DimOrderType b_dim_order[kTensorDimensionLimit]; + exec_aten::DimOrderType c_dim_order[kTensorDimensionLimit]; + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(a, a_dim_order, a.dim()) == Error::Ok, + "Failed to retrieve dim order from first input tensor!"); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(b, b_dim_order, b.dim()) == Error::Ok, + "Failed to retrieve dim order from second input tensor!"); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(c, c_dim_order, c.dim()) == Error::Ok, + "Failed to retrieve dim order from third input tensor!"); + + bool all_contiguous = is_contiguous_dim_order(a_dim_order, a.dim()) && + is_contiguous_dim_order(b_dim_order, b.dim()) && + is_contiguous_dim_order(c_dim_order, c.dim()); + + bool all_channels_last = is_channels_last_dim_order(a_dim_order, a.dim()) && + is_channels_last_dim_order(b_dim_order, b.dim()) && + is_channels_last_dim_order(c_dim_order, c.dim()); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Three input tensors have different dim orders"); + + return true; +} + namespace internal { Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) { diff --git a/runtime/core/exec_aten/util/tensor_util_portable.cpp b/runtime/core/exec_aten/util/tensor_util_portable.cpp index c7872d1499a..7e9a15f09a9 100644 --- a/runtime/core/exec_aten/util/tensor_util_portable.cpp +++ b/runtime/core/exec_aten/util/tensor_util_portable.cpp @@ -73,6 +73,40 @@ bool tensor_is_default_or_channels_last_dim_order(torch::executor::Tensor t) { return ret_val; } +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b) { + bool all_contiguous = + is_contiguous_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_contiguous_dim_order(b.dim_order().data(), b.dim_order().size()); + bool all_channels_last = + is_channels_last_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_channels_last_dim_order(b.dim_order().data(), b.dim_order().size()); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Two input tensors have different dim orders"); + + return true; +} + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b, + const exec_aten::Tensor& c) { + bool all_contiguous = + is_contiguous_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_contiguous_dim_order(b.dim_order().data(), b.dim_order().size()) && + is_contiguous_dim_order(c.dim_order().data(), c.dim_order().size()); + bool all_channels_last = + is_channels_last_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_channels_last_dim_order(b.dim_order().data(), b.dim_order().size()) && + is_channels_last_dim_order(c.dim_order().data(), c.dim_order().size()); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Three input tensors have different dim orders"); + return true; +} namespace internal { Error share_tensor_data( diff --git a/runtime/core/exec_aten/util/test/targets.bzl b/runtime/core/exec_aten/util/test/targets.bzl index cbd31013b5b..615b7c99a44 100644 --- a/runtime/core/exec_aten/util/test/targets.bzl +++ b/runtime/core/exec_aten/util/test/targets.bzl @@ -16,16 +16,6 @@ def define_common_targets(): ], ) - runtime.cxx_test( - name = "tensor_util_test", - srcs = ["tensor_util_test.cpp"], - deps = [ - "//executorch/runtime/core/exec_aten/testing_util:tensor_util", - "//executorch/runtime/core/exec_aten/util:scalar_type_util", - "//executorch/runtime/core/exec_aten/util:tensor_util", - ], - ) - runtime.cxx_test( name = "operator_impl_example_test", srcs = ["operator_impl_example_test.cpp"], @@ -44,3 +34,15 @@ def define_common_targets(): "//executorch/runtime/core/exec_aten/util:tensor_util", ], ) + + for aten_mode in (True, False): + aten_suffix = "_aten" if aten_mode else "" + runtime.cxx_test( + name = "tensor_util_test" + aten_suffix, + srcs = ["tensor_util_test.cpp"], + deps = [ + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + "//executorch/runtime/core/exec_aten/util:scalar_type_util", + "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix, + ], + ) diff --git a/runtime/core/exec_aten/util/test/tensor_util_test.cpp b/runtime/core/exec_aten/util/test/tensor_util_test.cpp index 53ff06966c2..88588dade68 100644 --- a/runtime/core/exec_aten/util/test/tensor_util_test.cpp +++ b/runtime/core/exec_aten/util/test/tensor_util_test.cpp @@ -14,8 +14,6 @@ #include #include -#include - using namespace ::testing; using exec_aten::ScalarType; using exec_aten::Tensor; @@ -553,3 +551,57 @@ TEST_F(TensorUtilTest, ResizeZeroDimTensor) { executorch::runtime::Error::Ok); EXPECT_EQ(a.dim(), 0); } + +TEST_F(TensorUtilTest, SameDimOrderContiguous) { + using namespace torch::executor; + // Three different tensors with the same shape and same dim order + // ([0, 1, 2, 3]), but different dtypes and contents. + std::vector sizes = {3, 5, 2, 1}; + Tensor a = tf_byte_.ones(sizes); + Tensor b = tf_int_.zeros(sizes); + Tensor c = tf_float_.full(sizes, 0.1); + + // The tensors have the same dim order, should pass the following checks. + EXPECT_TRUE(tensors_have_same_dim_order(a, b)); + EXPECT_TRUE(tensors_have_same_dim_order(b, a)); + EXPECT_TRUE(tensors_have_same_dim_order(a, b, c)); + EXPECT_TRUE(tensors_have_same_dim_order(b, c, a)); + EXPECT_TRUE(tensors_have_same_dim_order(c, a, b)); +} + +TEST_F(TensorUtilTest, SameDimOrderChannelsLast) { + using namespace torch::executor; + // Three different tensors with the same shape and same dim order + // ([0, 2, 3, 1]), but different dtypes and contents. + std::vector sizes = {3, 5, 2, 1}; + Tensor a = tf_byte_.full_channels_last(sizes, 1); + Tensor b = tf_int_.full_channels_last(sizes, 0); + Tensor c = tf_float_.full_channels_last(sizes, 0.1); + + // The tensors have the same dim order, should pass the following checks. + EXPECT_TRUE(tensors_have_same_dim_order(a, b)); + EXPECT_TRUE(tensors_have_same_dim_order(b, a)); + EXPECT_TRUE(tensors_have_same_dim_order(a, b, c)); + EXPECT_TRUE(tensors_have_same_dim_order(b, c, a)); + EXPECT_TRUE(tensors_have_same_dim_order(c, a, b)); +} + +TEST_F(TensorUtilTest, SameShapesDifferentDimOrder) { + using namespace torch::executor; + // Three different tensors with the same shape but different dtypes and + // contents, where b and c have the same dim order ([0, 2, 3, 1]) while a is + // different ([0, 1, 2, 3]). + std::vector sizes = {3, 5, 2, 1}; + Tensor a = tf_byte_.ones(sizes); + Tensor b = tf_int_.full_channels_last(sizes, 0); + Tensor c = tf_float_.full_channels_last(sizes, 0.1); + + // Not the same dim order. Chec + EXPECT_FALSE(tensors_have_same_dim_order(a, b)); + EXPECT_FALSE(tensors_have_same_dim_order(b, a)); + + // Test with a mismatching tensor in all positions, where the other two agree. + EXPECT_FALSE(tensors_have_same_dim_order(a, b, c)); + EXPECT_FALSE(tensors_have_same_dim_order(a, c, b)); + EXPECT_FALSE(tensors_have_same_dim_order(c, b, a)); +}