From 83fec013b938dfeea58832d09d494f0156b4d553 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 2 Sep 2024 22:26:23 -0700 Subject: [PATCH] introduce dim order tests to op test (#2637) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/2637 This diff introduces dim order sanity check utils, as well as dim-order related test to operator tests, to help our system maintain its correctness when introducing new dim order ([0, 2, 3, 1]) which we never support before. The goal is checking whether or not every operator support its input's memory format, and using related tests for regular tests. The high levels of sanity check and test will be: 1. the dim order of input and output should be same. 2. the dim order of all input tensors should be same, unless operaotr-specific requirement for some input (e.g. some operator may request some input have to be contiguous, although I haven't found the actual example yet.) 3. make the operator support as much dim order as possible (e,g, if a operator can support both contiguous and channels last, then the sanity check has to make the both input valid.) I also updated `op_abs` in this diff to demonstrate how the sanity check as well as tests will be inserted. Differential Revision: D55227304 --- kernels/portable/cpu/op_abs.cpp | 2 + kernels/test/TestUtil.h | 36 ++++ kernels/test/op_abs_test.cpp | 25 +++ .../exec_aten/testing_util/tensor_factory.h | 156 ++++++++++++++++-- .../testing_util/test/tensor_factory_test.cpp | 2 +- runtime/core/exec_aten/util/tensor_util.h | 116 ++++++++----- .../core/exec_aten/util/tensor_util_aten.cpp | 58 +++++++ .../exec_aten/util/tensor_util_portable.cpp | 34 ++++ runtime/core/exec_aten/util/test/targets.bzl | 22 +-- .../exec_aten/util/test/tensor_util_test.cpp | 56 ++++++- 10 files changed, 443 insertions(+), 64 deletions(-) diff --git a/kernels/portable/cpu/op_abs.cpp b/kernels/portable/cpu/op_abs.cpp index 0dd925a0e25..9c2c219832d 100644 --- a/kernels/portable/cpu/op_abs.cpp +++ b/kernels/portable/cpu/op_abs.cpp @@ -28,6 +28,8 @@ Tensor& abs_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { "Failed to resize output tensor."); ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] { apply_unary_map_fn( diff --git a/kernels/test/TestUtil.h b/kernels/test/TestUtil.h index ed72dbc4128..8d782d3c2a9 100644 --- a/kernels/test/TestUtil.h +++ b/kernels/test/TestUtil.h @@ -30,6 +30,22 @@ #define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \ EXPECT_ANY_THROW(_statement) +#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \ + tf, op, input_contiguous, expected_contiguous, channels_last_support) \ + Tensor input_channels_last = tf.channels_last_like(input_contiguous); \ + Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \ + \ + Tensor output_contiguous = tf.zeros_like(expected_contiguous); \ + Tensor output_channels_last = tf.channels_last_like(output_contiguous); \ + \ + Tensor ret = op(input_channels_last, output_channels_last); \ + if (channels_last_support) { \ + EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \ + } else { \ + EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \ + } \ + EXPECT_TENSOR_EQ(output_channels_last, ret); + #else #define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \ @@ -52,6 +68,26 @@ } \ } while (false) +#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( \ + tf, op, input_contiguous, expected_contiguous, channels_last_support) \ + Tensor input_channels_last = tf.channels_last_like(input_contiguous); \ + Tensor expected_channel_last = tf.channels_last_like(expected_contiguous); \ + \ + Tensor output_contiguous = tf.zeros_like(expected_contiguous); \ + Tensor output_channels_last = tf.channels_last_like(output_contiguous); \ + \ + Tensor ret = op(input_channels_last, output_channels_last); \ + if (channels_last_support) { \ + EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last); \ + } else { \ + EXPECT_TENSOR_NE(output_channels_last, expected_channel_last); \ + } \ + EXPECT_TENSOR_EQ(output_channels_last, ret); \ + ET_EXPECT_KERNEL_FAILURE( \ + context_, op(input_channels_last, output_contiguous)); \ + ET_EXPECT_KERNEL_FAILURE( \ + context_, op(input_contiguous, output_channels_last)); + #endif // USE_ATEN_LIB /* diff --git a/kernels/test/op_abs_test.cpp b/kernels/test/op_abs_test.cpp index b54cd971567..f596d586d90 100644 --- a/kernels/test/op_abs_test.cpp +++ b/kernels/test/op_abs_test.cpp @@ -38,3 +38,28 @@ TEST_F(OpAbsTest, SanityCheck) { EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, expected); } + +TEST_F(OpAbsTest, MemoryFormatCheck) { + TensorFactory tf; + + std::vector sizes = {2, 3, 1, 5}; + + Tensor input_contiguous = + tf.make(sizes, {0.8737, 0.5359, 0.3743, -0.3040, -0.7800, -0.2306, + -0.7684, -0.5364, 0.3478, -0.3289, 0.0829, 0.2939, + -0.8211, 0.8572, -0.0802, 0.9252, -0.2093, 0.9013, + -0.4197, 0.3987, -0.5291, -0.5567, 0.2691, 0.7819, + -0.8009, -0.4286, -0.9299, 0.2143, 0.2565, -0.5701}); + Tensor expected_contiguous = tf.make( + sizes, {0.8737, 0.5359, 0.3743, 0.3040, 0.7800, 0.2306, 0.7684, 0.5364, + 0.3478, 0.3289, 0.0829, 0.2939, 0.8211, 0.8572, 0.0802, 0.9252, + 0.2093, 0.9013, 0.4197, 0.3987, 0.5291, 0.5567, 0.2691, 0.7819, + 0.8009, 0.4286, 0.9299, 0.2143, 0.2565, 0.5701}); + + ET_TEST_OP_SUPPORTS_MEMORY_FORMATS( + tf, + op_abs_out, + input_contiguous, + expected_contiguous, + /*channels_last_support=*/true); +} diff --git a/runtime/core/exec_aten/testing_util/tensor_factory.h b/runtime/core/exec_aten/testing_util/tensor_factory.h index 8f39cc9911d..3045af55819 100644 --- a/runtime/core/exec_aten/testing_util/tensor_factory.h +++ b/runtime/core/exec_aten/testing_util/tensor_factory.h @@ -3,8 +3,10 @@ #pragma once #include +#include #include +#include #include #include #include @@ -54,7 +56,7 @@ inline size_t sizes_to_numel(const std::vector& sizes) { inline bool check_strides( const std::vector sizes, - const std::vector strides) { + const std::vector strides) { if (sizes.size() != strides.size()) { // The length of stride vector shall equal to size vector. return false; @@ -147,14 +149,14 @@ inline bool check_dim_order( return true; } -inline std::vector strides_from_dim_order( +inline std::vector strides_from_dim_order( const std::vector& sizes, const std::vector& dim_order) { bool legal = check_dim_order(sizes, dim_order); ET_CHECK_MSG(legal, "The input dim_order variable is illegal."); size_t ndim = sizes.size(); - std::vector strides(ndim); + std::vector strides(ndim); strides[dim_order[ndim - 1]] = 1; for (int i = ndim - 2; i >= 0; --i) { uint8_t cur_dim = dim_order[i]; @@ -258,7 +260,7 @@ class TensorFactory { at::Tensor make( const std::vector& sizes, const std::vector& data, - const std::vector strides = {}, + const std::vector strides = {}, ET_UNUSED TensorShapeDynamism dynamism = TensorShapeDynamism::DYNAMIC_UNBOUND) { auto expected_numel = internal::sizes_to_numel(sizes); @@ -344,6 +346,72 @@ class TensorFactory { sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism); } + /** + * Given data in contiguous memory format, returns a new Tensor with the + * specified shape and the same data but in channels last memory format. + * + * @param[in] sizes The sizes of the dimensions of the Tensor. + * @param[in] data The data in contiguous memory format that the Tensor should + * be initialized with. The size of this vector must be equal to the product + * of the elements of `sizes`. + * + * @return A new Tensor with the specified shape and data in channls last + * memory format. + */ + at::Tensor channels_last_like( + const at::Tensor& input, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + ET_CHECK_MSG( + input.sizes().size() == 4, "Only 4D tensors can be channels last"); + + const std::vector sizes( + input.sizes().begin(), input.sizes().end()); + + std::vector contiguous_dim_order(sizes.size()); + for (uint8_t i = 0; i < sizes.size(); i++) { + contiguous_dim_order[i] = i; + } + std::vector contiguous_strides = + internal::strides_from_dim_order(sizes, contiguous_dim_order); + + for (int32_t i = 0; i < input.dim(); i++) { + ET_CHECK_MSG( + input.strides()[i] == contiguous_strides[i], + "Input tensor is not contiguous"); + } + + int32_t N = sizes[0]; + int32_t C = sizes[1]; + int32_t H = sizes[2]; + int32_t W = sizes[3]; + + std::vector contiguous_data( + input.data_ptr(), input.data_ptr() + input.numel()); + std::vector channels_last_data( + N * C * H * W); // Create a new blob with the same total size to contain + // channels_last data + for (int32_t n = 0; n < N; ++n) { + for (int32_t c = 0; c < C; ++c) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + // Calculate the index in the original blob + int32_t old_index = ((n * C + c) * H + h) * W + w; + // Calculate the index in the new blob + int32_t new_index = ((n * H + h) * W + w) * C + c; + // Copy the data + channels_last_data[new_index] = contiguous_data[old_index]; + } + } + } + } + + return make_with_dimorder( + sizes, + channels_last_data, + internal::channels_last_dim_order(sizes.size()), + dynamism); + } + /** * Returns a new Tensor with the specified shape, containing contiguous * data will all elements set to `value`. @@ -459,14 +527,13 @@ class TensorFactory { */ at::Tensor empty_strided( const std::vector& sizes, - const std::vector& strides, + const std::vector& strides, ET_UNUSED TensorShapeDynamism dynamism = TensorShapeDynamism::DYNAMIC_UNBOUND) { auto sizes64 = vec_32_to_64(sizes); - auto strides64 = vec_32_to_64(strides); return at::empty_strided( sizes64, - strides64, + strides, DTYPE, /*layout_opt=*/at::Layout::Strided, /*device_opt=*/at::Device(at::DeviceType::CPU), @@ -666,7 +733,7 @@ class TensorFactory { torch::executor::Tensor make( const std::vector& sizes, const std::vector& data, - const std::vector strides = {}, + const std::vector strides = {}, TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { std::vector default_strides; // Generate strides from the tensor dimensions, assuming contiguous data if @@ -746,7 +813,7 @@ class TensorFactory { /** * Returns a new Tensor with the specified shape and data in channels last - * memory layout. + * memory format. * * @param[in] sizes The sizes of the dimensions of the Tensor. * @param[in] data The data that the Tensor should be initialized with. The @@ -764,6 +831,60 @@ class TensorFactory { sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism); } + /** + * Given data in contiguous memory format, returns a new Tensor with the + * specified shape and the same data but in channels last memory format. + * + * @param[in] sizes The sizes of the dimensions of the Tensor. + * @param[in] data The data in contiguous memory format that the Tensor should + * be initialized with. The size of this vector must be equal to the product + * of the elements of `sizes`. + * + * @return A new Tensor with the specified shape and data in channls last + * memory format. + */ + torch::executor::Tensor channels_last_like( + const torch::executor::Tensor& input, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + const std::vector sizes( + input.sizes().begin(), input.sizes().end()); + + ET_CHECK_MSG(sizes.size() == 4, "Only 4D tensors can be channels last"); + ET_CHECK_MSG( + is_contiguous_dim_order(input.dim_order().data(), input.dim()) == true, + "Input tensor is not contiguous"); + int32_t N = sizes[0]; + int32_t C = sizes[1]; + int32_t H = sizes[2]; + int32_t W = sizes[3]; + + std::vector contiguous_data( + input.data_ptr(), input.data_ptr() + input.numel()); + std::vector channels_last_data( + N * C * H * W); // Create a new blob with the same total size to contain + // channels_last data + for (int32_t n = 0; n < N; ++n) { + for (int32_t c = 0; c < C; ++c) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + // Calculate the index in the original blob + int32_t old_index = ((n * C + c) * H + h) * W + w; + // Calculate the index in the new blob + int32_t new_index = ((n * H + h) * W + w) * C + c; + // Copy the data + channels_last_data[new_index] = contiguous_data[old_index]; + } + } + } + } + + return make_with_dimorder( + sizes, + channels_last_data, + internal::channels_last_dim_order(sizes.size()), + dynamism); + } + /** * Returns a new Tensor with the specified shape, containing contiguous data * will all elements set to `value`. @@ -799,7 +920,20 @@ class TensorFactory { /** * Returns a new Tensor with the specified shape, containing contiguous data - * with all `0` elements. + * in channels last memory format with all `0` elements. + * + * @param[in] sizes The sizes of the dimensions of the Tensor. + * @return A new Tensor with the specified shape. + */ + torch::executor::Tensor zeros_channels_last( + const std::vector& sizes, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + return full_channels_last(sizes, 0, dynamism); + } + + /** + * Returns a new Tensor with the specified shape, containing contiguous data + * in contiguous memory format with all `0` elements. * * @param[in] sizes The sizes of the dimensions of the Tensor. * @return A new Tensor with the specified shape. @@ -878,7 +1012,7 @@ class TensorFactory { std::vector sizes_; std::vector data_; std::vector dim_order_; - std::vector strides_; + std::vector strides_; torch::executor::TensorImpl impl_; }; diff --git a/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp b/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp index a2bc36f4814..8681e9553a6 100644 --- a/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp +++ b/runtime/core/exec_aten/testing_util/test/tensor_factory_test.cpp @@ -449,7 +449,7 @@ TEST_F(TensorFactoryTest, MakeStridedDataIsCopied) { // Create two tensors using the same input data and strided vector. std::vector data = {1, 2, 3, 4}; - std::vector strides = {1, 2}; + std::vector strides = {1, 2}; Tensor t1 = tf.make(/*sizes=*/{2, 2}, data, strides); Tensor t2 = tf.make(/*sizes=*/{2, 2}, data, strides); diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index b18cd349a62..4dcb0ef9f69 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -235,8 +235,9 @@ */ #define ET_CHECK_CONTIGUOUS(a__) \ ({ \ - const ::exec_aten::ArrayRef strides = a__.strides(); \ - const ::exec_aten::ArrayRef sizes = a__.sizes(); \ + const ::exec_aten::ArrayRef strides = \ + a__.strides(); \ + const ::exec_aten::ArrayRef sizes = a__.sizes(); \ ET_CHECK_MSG( \ strides[strides.size() - 1] == 1, \ "The stride of the last dimension shall be 1 for contiguous tensor, " \ @@ -267,8 +268,10 @@ "Two tensors shall have same number of strides, but not %zu and %zu.", \ a__.dim(), \ b__.dim()); \ - const ::exec_aten::ArrayRef a_strides = a__.strides(); \ - const ::exec_aten::ArrayRef b_strides = b__.strides(); \ + const ::exec_aten::ArrayRef a_strides = \ + a__.strides(); \ + const ::exec_aten::ArrayRef b_strides = \ + b__.strides(); \ for (size_t i = 0; i < a__.dim(); i++) { \ ET_CHECK_MSG( \ a_strides[i] == b_strides[i], \ @@ -276,8 +279,8 @@ "but now is %d and %d.", \ i, \ i, \ - a_strides[i], \ - b_strides[i]); \ + (int32_t)a_strides[i], \ + (int32_t)b_strides[i]); \ } \ }) @@ -295,9 +298,12 @@ a__.dim(), \ b__.dim(), \ c__.dim()); \ - const ::exec_aten::ArrayRef a_strides = a__.strides(); \ - const ::exec_aten::ArrayRef b_strides = b__.strides(); \ - const ::exec_aten::ArrayRef c_strides = c__.strides(); \ + const ::exec_aten::ArrayRef a_strides = \ + a__.strides(); \ + const ::exec_aten::ArrayRef b_strides = \ + b__.strides(); \ + const ::exec_aten::ArrayRef c_strides = \ + c__.strides(); \ for (size_t i = 0; i < a__.dim(); i++) { \ ET_CHECK_MSG( \ a_strides[i] == b_strides[i] && b_strides[i] == c_strides[i], \ @@ -306,9 +312,9 @@ i, \ i, \ i, \ - a_strides[i], \ - b_strides[i], \ - c_strides[i]); \ + (int32_t)a_strides[i], \ + (int32_t)b_strides[i], \ + (int32_t)c_strides[i]); \ } \ }) @@ -848,11 +854,11 @@ inline bool tensor_is_scalar(exec_aten::Tensor t) { /** * The expected output size may not be the existing size of any inputs and - * outputs if the operator supports both broadcast and dynamic shape. Therefore - * such operators needs extra space to store the calculated expected output - * size. such dynamic allocation is troublesome in executorch so we can just - * hard code a static value of a relatively small value because users don't - * create high dimensional tensors. + * outputs if the operator supports both broadcast and dynamic shape. + * Therefore such operators needs extra space to store the calculated expected + * output size. such dynamic allocation is troublesome in executorch so we can + * just hard code a static value of a relatively small value because users + * don't create high dimensional tensors. */ constexpr size_t kTensorDimensionLimit = 16; @@ -893,8 +899,8 @@ inline size_t getTrailingDims(const exec_aten::Tensor& tensor, int64_t dim) { * @param[in] tensor The tensor that will be indexed * @param[in] coordinate A n-dimensional array representing the coordinate to * index. It is assumed that the array has kTensorDimensionLimit elements. - * @param[out] index The linear index to element at the specified coordinate in - * the tensor. + * @param[out] index The linear index to element at the specified coordinate + * in the tensor. */ inline size_t coordinateToIndex( const exec_aten::Tensor& tensor, @@ -935,10 +941,10 @@ inline void indexToCoordinate( * * @param[in] tensor The source of the value to extract. * @param[out] out_val The extracted value, on success. - * @returns `true` if a value was extracted, and sets `*out_val` to that value. - * `false` if a value could not be extracted: either it was not an integer - * Scalar Tensor, or the value of that Scalar Tensor could not be represented - * by INT_T. + * @returns `true` if a value was extracted, and sets `*out_val` to that + * value. `false` if a value could not be extracted: either it was not an + * integer Scalar Tensor, or the value of that Scalar Tensor could not be + * represented by INT_T. */ template < typename INT_T, @@ -973,10 +979,10 @@ bool extract_scalar_tensor(exec_aten::Tensor tensor, INT_T* out_val) { * * @param[in] tensor The source of the value to extract. * @param[out] out_val The extracted value, on success. - * @returns `true` if a value was extracted, and sets `*out_val` to that value. - * `false` if a value could not be extracted: either it was not a floating - * point Scalar Tensor, or the value of that Scalar Tensor could not be - * represented by FLOAT_T. + * @returns `true` if a value was extracted, and sets `*out_val` to that + * value. `false` if a value could not be extracted: either it was not a + * floating point Scalar Tensor, or the value of that Scalar Tensor could not + * be represented by FLOAT_T. */ template < typename FLOAT_T, @@ -1076,9 +1082,9 @@ ET_NODISCARD Error resize_tensor_impl( * expand the tensor if new size exceeds the current capacity. Currently * fails an ET_CHECK if the tensor cannot be resized. * - * WARNING: Placeholder API until discussion around runtime context is settled, - * will likely move to be a class method on a TensorResizer object passed in - * through runtimeContext. + * WARNING: Placeholder API until discussion around runtime context is + * settled, will likely move to be a class method on a TensorResizer object + * passed in through runtimeContext. */ ET_NODISCARD inline Error resize_tensor( exec_aten::Tensor t, @@ -1091,9 +1097,9 @@ ET_NODISCARD inline Error resize_tensor( * expand the tensor if new size exceeds the current capacity. Currently * fails an ET_CHECK if the tensor cannot be resized. * - * WARNING: Placeholder API until discussion around runtime context is settled, - * will likely move to be a class method on a TensorResizer object passed in - * through runtimeContext. + * WARNING: Placeholder API until discussion around runtime context is + * settled, will likely move to be a class method on a TensorResizer object + * passed in through runtimeContext. */ template < typename T, @@ -1124,8 +1130,8 @@ ET_DEPRECATED inline void resize( /** * Get dim_order of a Tensor and write it to out_dim_order. * @param tensor The tensor where we want to get dim order from. - * @param out_dim_order Pointing to an array of DimOrderType where we write dim - * order into it. + * @param out_dim_order Pointing to an array of DimOrderType where we write + * dim order into it. * @param out_dim_order_size Size of the DimOrderType array. */ ET_NODISCARD Error get_dim_order( @@ -1134,18 +1140,47 @@ ET_NODISCARD Error get_dim_order( size_t out_dim_order_size); /** - * Checks whether a tensor has a valid dim order. If the dim order could not be - * determined, then this function returns false by default. + * Checks whether a tensor has a valid dim order. If the dim order could not + * be determined, then this function returns false by default. */ bool tensor_has_valid_dim_order(exec_aten::Tensor t); /** - * Checks whether a tensor has either the default of channels last dim order. If - * the dim order could not be determined, then this function returns false by - * default. + * Checks whether a tensor has either the default of channels last dim order. + * If the dim order could not be determined, then this function returns false + * by default. */ bool tensor_is_default_or_channels_last_dim_order(exec_aten::Tensor t); +/** + * Asserts that two tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. Also this macro does not support ATen mode since we do not + * support dim order in ATen mode. + * + * TODO(T183094318): Add dim order and related function support for ATen mode. + */ + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b); + +/** + * Asserts that three tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. Also this macro does not support ATen mode since we do not + * support dim order in ATen mode. + * + * TODO(T183094318): Add dim order and related function support for ATen mode. + */ + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b, + const exec_aten::Tensor& c); + /** * Given an n-dimensional coordinate array and an array of tensor strides, * calculates the linear index that can be used to retrieve the value at the @@ -1205,6 +1240,7 @@ using ::executorch::runtime::tensor_is_real_type; using ::executorch::runtime::tensor_is_realh_type; using ::executorch::runtime::tensor_is_realhb_type; using ::executorch::runtime::tensor_is_scalar; +using ::executorch::runtime::tensors_have_same_dim_order; using ::executorch::runtime::tensors_have_same_dtype; using ::executorch::runtime::tensors_have_same_rank; using ::executorch::runtime::tensors_have_same_shape; diff --git a/runtime/core/exec_aten/util/tensor_util_aten.cpp b/runtime/core/exec_aten/util/tensor_util_aten.cpp index c5ff3b52234..91b75c06483 100644 --- a/runtime/core/exec_aten/util/tensor_util_aten.cpp +++ b/runtime/core/exec_aten/util/tensor_util_aten.cpp @@ -77,6 +77,64 @@ inline bool tensor_is_default_or_channels_last_dim_order(at::Tensor t) { return ret_val; } +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b) { + exec_aten::DimOrderType a_dim_order[kTensorDimensionLimit]; + exec_aten::DimOrderType b_dim_order[kTensorDimensionLimit]; + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(a, a_dim_order, a.dim()) == Error::Ok, + "Failed to retrieve dim order from first input tensor!"); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(b, b_dim_order, b.dim()) == Error::Ok, + "Failed to retrieve dim order from second input tensor!"); + + bool all_contiguous = is_contiguous_dim_order(a_dim_order, a.dim()) && + is_contiguous_dim_order(b_dim_order, b.dim()); + + bool all_channels_last = is_channels_last_dim_order(a_dim_order, a.dim()) && + is_channels_last_dim_order(b_dim_order, b.dim()); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Two input tensors have different dim orders"); + + return true; +} + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b, + const exec_aten::Tensor& c) { + exec_aten::DimOrderType a_dim_order[kTensorDimensionLimit]; + exec_aten::DimOrderType b_dim_order[kTensorDimensionLimit]; + exec_aten::DimOrderType c_dim_order[kTensorDimensionLimit]; + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(a, a_dim_order, a.dim()) == Error::Ok, + "Failed to retrieve dim order from first input tensor!"); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(b, b_dim_order, b.dim()) == Error::Ok, + "Failed to retrieve dim order from second input tensor!"); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + get_dim_order(c, c_dim_order, c.dim()) == Error::Ok, + "Failed to retrieve dim order from third input tensor!"); + + bool all_contiguous = is_contiguous_dim_order(a_dim_order, a.dim()) && + is_contiguous_dim_order(b_dim_order, b.dim()) && + is_contiguous_dim_order(c_dim_order, c.dim()); + + bool all_channels_last = is_channels_last_dim_order(a_dim_order, a.dim()) && + is_channels_last_dim_order(b_dim_order, b.dim()) && + is_channels_last_dim_order(c_dim_order, c.dim()); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Three input tensors have different dim orders"); + + return true; +} + namespace internal { Error share_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) { diff --git a/runtime/core/exec_aten/util/tensor_util_portable.cpp b/runtime/core/exec_aten/util/tensor_util_portable.cpp index c7872d1499a..7e9a15f09a9 100644 --- a/runtime/core/exec_aten/util/tensor_util_portable.cpp +++ b/runtime/core/exec_aten/util/tensor_util_portable.cpp @@ -73,6 +73,40 @@ bool tensor_is_default_or_channels_last_dim_order(torch::executor::Tensor t) { return ret_val; } +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b) { + bool all_contiguous = + is_contiguous_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_contiguous_dim_order(b.dim_order().data(), b.dim_order().size()); + bool all_channels_last = + is_channels_last_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_channels_last_dim_order(b.dim_order().data(), b.dim_order().size()); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Two input tensors have different dim orders"); + + return true; +} + +bool tensors_have_same_dim_order( + const exec_aten::Tensor& a, + const exec_aten::Tensor& b, + const exec_aten::Tensor& c) { + bool all_contiguous = + is_contiguous_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_contiguous_dim_order(b.dim_order().data(), b.dim_order().size()) && + is_contiguous_dim_order(c.dim_order().data(), c.dim_order().size()); + bool all_channels_last = + is_channels_last_dim_order(a.dim_order().data(), a.dim_order().size()) && + is_channels_last_dim_order(b.dim_order().data(), b.dim_order().size()) && + is_channels_last_dim_order(c.dim_order().data(), c.dim_order().size()); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + all_contiguous || all_channels_last, + "Three input tensors have different dim orders"); + return true; +} namespace internal { Error share_tensor_data( diff --git a/runtime/core/exec_aten/util/test/targets.bzl b/runtime/core/exec_aten/util/test/targets.bzl index cbd31013b5b..615b7c99a44 100644 --- a/runtime/core/exec_aten/util/test/targets.bzl +++ b/runtime/core/exec_aten/util/test/targets.bzl @@ -16,16 +16,6 @@ def define_common_targets(): ], ) - runtime.cxx_test( - name = "tensor_util_test", - srcs = ["tensor_util_test.cpp"], - deps = [ - "//executorch/runtime/core/exec_aten/testing_util:tensor_util", - "//executorch/runtime/core/exec_aten/util:scalar_type_util", - "//executorch/runtime/core/exec_aten/util:tensor_util", - ], - ) - runtime.cxx_test( name = "operator_impl_example_test", srcs = ["operator_impl_example_test.cpp"], @@ -44,3 +34,15 @@ def define_common_targets(): "//executorch/runtime/core/exec_aten/util:tensor_util", ], ) + + for aten_mode in (True, False): + aten_suffix = "_aten" if aten_mode else "" + runtime.cxx_test( + name = "tensor_util_test" + aten_suffix, + srcs = ["tensor_util_test.cpp"], + deps = [ + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + "//executorch/runtime/core/exec_aten/util:scalar_type_util", + "//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix, + ], + ) diff --git a/runtime/core/exec_aten/util/test/tensor_util_test.cpp b/runtime/core/exec_aten/util/test/tensor_util_test.cpp index 53ff06966c2..88588dade68 100644 --- a/runtime/core/exec_aten/util/test/tensor_util_test.cpp +++ b/runtime/core/exec_aten/util/test/tensor_util_test.cpp @@ -14,8 +14,6 @@ #include #include -#include - using namespace ::testing; using exec_aten::ScalarType; using exec_aten::Tensor; @@ -553,3 +551,57 @@ TEST_F(TensorUtilTest, ResizeZeroDimTensor) { executorch::runtime::Error::Ok); EXPECT_EQ(a.dim(), 0); } + +TEST_F(TensorUtilTest, SameDimOrderContiguous) { + using namespace torch::executor; + // Three different tensors with the same shape and same dim order + // ([0, 1, 2, 3]), but different dtypes and contents. + std::vector sizes = {3, 5, 2, 1}; + Tensor a = tf_byte_.ones(sizes); + Tensor b = tf_int_.zeros(sizes); + Tensor c = tf_float_.full(sizes, 0.1); + + // The tensors have the same dim order, should pass the following checks. + EXPECT_TRUE(tensors_have_same_dim_order(a, b)); + EXPECT_TRUE(tensors_have_same_dim_order(b, a)); + EXPECT_TRUE(tensors_have_same_dim_order(a, b, c)); + EXPECT_TRUE(tensors_have_same_dim_order(b, c, a)); + EXPECT_TRUE(tensors_have_same_dim_order(c, a, b)); +} + +TEST_F(TensorUtilTest, SameDimOrderChannelsLast) { + using namespace torch::executor; + // Three different tensors with the same shape and same dim order + // ([0, 2, 3, 1]), but different dtypes and contents. + std::vector sizes = {3, 5, 2, 1}; + Tensor a = tf_byte_.full_channels_last(sizes, 1); + Tensor b = tf_int_.full_channels_last(sizes, 0); + Tensor c = tf_float_.full_channels_last(sizes, 0.1); + + // The tensors have the same dim order, should pass the following checks. + EXPECT_TRUE(tensors_have_same_dim_order(a, b)); + EXPECT_TRUE(tensors_have_same_dim_order(b, a)); + EXPECT_TRUE(tensors_have_same_dim_order(a, b, c)); + EXPECT_TRUE(tensors_have_same_dim_order(b, c, a)); + EXPECT_TRUE(tensors_have_same_dim_order(c, a, b)); +} + +TEST_F(TensorUtilTest, SameShapesDifferentDimOrder) { + using namespace torch::executor; + // Three different tensors with the same shape but different dtypes and + // contents, where b and c have the same dim order ([0, 2, 3, 1]) while a is + // different ([0, 1, 2, 3]). + std::vector sizes = {3, 5, 2, 1}; + Tensor a = tf_byte_.ones(sizes); + Tensor b = tf_int_.full_channels_last(sizes, 0); + Tensor c = tf_float_.full_channels_last(sizes, 0.1); + + // Not the same dim order. Chec + EXPECT_FALSE(tensors_have_same_dim_order(a, b)); + EXPECT_FALSE(tensors_have_same_dim_order(b, a)); + + // Test with a mismatching tensor in all positions, where the other two agree. + EXPECT_FALSE(tensors_have_same_dim_order(a, b, c)); + EXPECT_FALSE(tensors_have_same_dim_order(a, c, b)); + EXPECT_FALSE(tensors_have_same_dim_order(c, b, a)); +}