2 changes: 2 additions & 0 deletions kernels/portable/cpu/op_abs.cpp
@@ -28,6 +28,8 @@ Tensor& abs_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
"Failed to resize output tensor.");

ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
apply_unary_map_fn(
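For context, the guard added above compares the two tensors' dim-order arrays element-wise. A minimal standalone sketch of that comparison (a hypothetical helper for illustration only; the real `tensors_have_same_dim_order` lives in ExecuTorch's kernel utilities and takes Tensor arguments):

#include <cstddef>
#include <cstdint>

// Hypothetical sketch: two tensors share a dim order when their dim_order
// arrays are equal element-wise, e.g. both {0, 1, 2, 3} (contiguous NCHW)
// or both {0, 2, 3, 1} (channels last).
bool same_dim_order(
    const uint8_t* a, size_t a_dim, const uint8_t* b, size_t b_dim) {
  if (a_dim != b_dim) {
    return false;
  }
  for (size_t i = 0; i < a_dim; ++i) {
    if (a[i] != b[i]) {
      return false;
    }
  }
  return true;
}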
36 changes: 36 additions & 0 deletions kernels/test/TestUtil.h
@@ -30,6 +30,22 @@
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \
EXPECT_ANY_THROW(_statement)

#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                 \
    tf, op, input_contiguous, expected_contiguous, channels_last_support)   \
  Tensor input_channels_last = tf.channels_last_like(input_contiguous);     \
  Tensor expected_channels_last =                                           \
      tf.channels_last_like(expected_contiguous);                           \
                                                                            \
  Tensor output_contiguous = tf.zeros_like(expected_contiguous);            \
  Tensor output_channels_last = tf.channels_last_like(output_contiguous);   \
                                                                            \
  Tensor ret = op(input_channels_last, output_channels_last);               \
  if (channels_last_support) {                                              \
    EXPECT_TENSOR_EQ(output_channels_last, expected_channels_last);         \
  } else {                                                                  \
    EXPECT_TENSOR_NE(output_channels_last, expected_channels_last);         \
  }                                                                         \
  EXPECT_TENSOR_EQ(output_channels_last, ret);

#else

#define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
@@ -52,6 +68,26 @@
} \
} while (false)

#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                 \
    tf, op, input_contiguous, expected_contiguous, channels_last_support)   \
  Tensor input_channels_last = tf.channels_last_like(input_contiguous);     \
  Tensor expected_channels_last =                                           \
      tf.channels_last_like(expected_contiguous);                           \
                                                                            \
  Tensor output_contiguous = tf.zeros_like(expected_contiguous);            \
  Tensor output_channels_last = tf.channels_last_like(output_contiguous);   \
                                                                            \
  Tensor ret = op(input_channels_last, output_channels_last);               \
  if (channels_last_support) {                                              \
    EXPECT_TENSOR_EQ(output_channels_last, expected_channels_last);         \
  } else {                                                                  \
    EXPECT_TENSOR_NE(output_channels_last, expected_channels_last);         \
  }                                                                         \
  EXPECT_TENSOR_EQ(output_channels_last, ret);                              \
  ET_EXPECT_KERNEL_FAILURE(                                                 \
      context_, op(input_channels_last, output_contiguous));                \
  ET_EXPECT_KERNEL_FAILURE(                                                 \
      context_, op(input_contiguous, output_channels_last));

#endif // USE_ATEN_LIB
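Note that both expansions declare local `Tensor` variables (`input_channels_last`, `ret`, and friends) at the invocation site instead of wrapping the body in a `do { ... } while (false)` block, so the macro can be invoked at most once per test body.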

/*
25 changes: 25 additions & 0 deletions kernels/test/op_abs_test.cpp
@@ -38,3 +38,28 @@ TEST_F(OpAbsTest, SanityCheck) {
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, expected);
}

TEST_F(OpAbsTest, MemoryFormatCheck) {
TensorFactory<ScalarType::Float> tf;

std::vector<int32_t> sizes = {2, 3, 1, 5};

Tensor input_contiguous =
tf.make(sizes, {0.8737, 0.5359, 0.3743, -0.3040, -0.7800, -0.2306,
-0.7684, -0.5364, 0.3478, -0.3289, 0.0829, 0.2939,
-0.8211, 0.8572, -0.0802, 0.9252, -0.2093, 0.9013,
-0.4197, 0.3987, -0.5291, -0.5567, 0.2691, 0.7819,
-0.8009, -0.4286, -0.9299, 0.2143, 0.2565, -0.5701});
Tensor expected_contiguous = tf.make(
sizes, {0.8737, 0.5359, 0.3743, 0.3040, 0.7800, 0.2306, 0.7684, 0.5364,
0.3478, 0.3289, 0.0829, 0.2939, 0.8211, 0.8572, 0.0802, 0.9252,
0.2093, 0.9013, 0.4197, 0.3987, 0.5291, 0.5567, 0.2691, 0.7819,
0.8009, 0.4286, 0.9299, 0.2143, 0.2565, 0.5701});

ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(
tf,
op_abs_out,
input_contiguous,
expected_contiguous,
/*channels_last_support=*/true);
}
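In the portable (non-ATen) build, this single invocation also exercises the new check in op_abs.cpp: the macro's trailing ET_EXPECT_KERNEL_FAILURE calls pass one contiguous and one channels-last tensor to op_abs_out, which the kernel must now reject via tensors_have_same_dim_order.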
156 changes: 145 additions & 11 deletions runtime/core/exec_aten/testing_util/tensor_factory.h
@@ -3,8 +3,10 @@
#pragma once

#include <algorithm>
#include <cstdint>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/tensor_shape_dynamism.h>
#include <executorch/runtime/platform/assert.h>
@@ -54,7 +56,7 @@ inline size_t sizes_to_numel(const std::vector<int32_t>& sizes) {

inline bool check_strides(
const std::vector<int32_t> sizes,
const std::vector<int32_t> strides) {
const std::vector<exec_aten::StridesType> strides) {
if (sizes.size() != strides.size()) {
// The length of the strides vector must equal that of the sizes vector.
return false;
@@ -147,14 +149,14 @@ inline bool check_dim_order(
return true;
}

inline std::vector<int32_t> strides_from_dim_order(
inline std::vector<exec_aten::StridesType> strides_from_dim_order(
const std::vector<int32_t>& sizes,
const std::vector<uint8_t>& dim_order) {
bool legal = check_dim_order(sizes, dim_order);
ET_CHECK_MSG(legal, "The input dim_order variable is illegal.");

size_t ndim = sizes.size();
std::vector<int32_t> strides(ndim);
std::vector<exec_aten::StridesType> strides(ndim);
strides[dim_order[ndim - 1]] = 1;
for (int i = ndim - 2; i >= 0; --i) {
uint8_t cur_dim = dim_order[i];
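As a worked example of the construction above (assuming the truncated loop body computes the usual recurrence, strides[dim_order[i]] = strides[dim_order[i + 1]] * sizes[dim_order[i + 1]]):

// Sketch: strides for a {2, 3, 4, 5} (N, C, H, W) tensor stored channels
// last, i.e. dim_order {0, 2, 3, 1} (N outermost, C innermost in memory).
std::vector<int32_t> sizes = {2, 3, 4, 5};
std::vector<uint8_t> dim_order = {0, 2, 3, 1};
// strides[1] = 1                           (C, innermost)
// strides[3] = strides[1] * sizes[1] = 3   (W)
// strides[2] = strides[3] * sizes[3] = 15  (H)
// strides[0] = strides[2] * sizes[2] = 60  (N)
// => strides == {60, 1, 15, 3}
auto strides = internal::strides_from_dim_order(sizes, dim_order);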
@@ -258,7 +260,7 @@ class TensorFactory {
at::Tensor make(
const std::vector<int32_t>& sizes,
const std::vector<ctype>& data,
const std::vector<int32_t> strides = {},
const std::vector<exec_aten::StridesType> strides = {},
ET_UNUSED TensorShapeDynamism dynamism =
TensorShapeDynamism::DYNAMIC_UNBOUND) {
auto expected_numel = internal::sizes_to_numel(sizes);
@@ -344,6 +346,72 @@ class TensorFactory {
sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism);
}

/**
 * Given a 4-dimensional tensor in contiguous memory format, returns a new
 * Tensor with the same shape and the same data, but in channels last
 * memory format.
 *
 * @param[in] input The 4-dimensional, contiguous tensor to convert.
 * @param[in] dynamism How the shape of the returned Tensor may change.
 *
 * @return A new Tensor with the same shape and data in channels last
 * memory format.
 */
at::Tensor channels_last_like(
const at::Tensor& input,
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
ET_CHECK_MSG(
input.sizes().size() == 4, "Only 4D tensors can be channels last");

const std::vector<int32_t> sizes(
input.sizes().begin(), input.sizes().end());

std::vector<uint8_t> contiguous_dim_order(sizes.size());
for (uint8_t i = 0; i < sizes.size(); i++) {
contiguous_dim_order[i] = i;
}
std::vector<exec_aten::StridesType> contiguous_strides =
internal::strides_from_dim_order(sizes, contiguous_dim_order);

for (int32_t i = 0; i < input.dim(); i++) {
ET_CHECK_MSG(
input.strides()[i] == contiguous_strides[i],
"Input tensor is not contiguous");
}

int32_t N = sizes[0];
int32_t C = sizes[1];
int32_t H = sizes[2];
int32_t W = sizes[3];

std::vector<ctype> contiguous_data(
input.data_ptr<ctype>(), input.data_ptr<ctype>() + input.numel());
std::vector<ctype> channels_last_data(
N * C * H * W); // Create a new blob with the same total size to contain
// channels_last data
for (int32_t n = 0; n < N; ++n) {
for (int32_t c = 0; c < C; ++c) {
for (int32_t h = 0; h < H; ++h) {
for (int32_t w = 0; w < W; ++w) {
// Calculate the index in the original blob
int32_t old_index = ((n * C + c) * H + h) * W + w;
// Calculate the index in the new blob
int32_t new_index = ((n * H + h) * W + w) * C + c;
// Copy the data
channels_last_data[new_index] = contiguous_data[old_index];
}
}
}
}

return make_with_dimorder(
sizes,
channels_last_data,
internal::channels_last_dim_order(sizes.size()),
dynamism);
}
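A small worked example of the remap loop above: element (n, c, h, w) moves from contiguous offset ((n * C + c) * H + h) * W + w to channels-last offset ((n * H + h) * W + w) * C + c. A hedged usage sketch with hand-computed values:

// For a contiguous {1, 2, 2, 2} tensor whose data is {0, 1, 2, 3, 4, 5, 6, 7}
// (channel 0 holds {0, 1, 2, 3}, channel 1 holds {4, 5, 6, 7}), the remap
// interleaves the channels, so the returned tensor's underlying buffer is
// {0, 4, 1, 5, 2, 6, 3, 7}.
TensorFactory<ScalarType::Float> tf;
at::Tensor contiguous = tf.make({1, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
at::Tensor channels_last = tf.channels_last_like(contiguous);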

/**
* Returns a new Tensor with the specified shape, containing contiguous
* data with all elements set to `value`.
@@ -459,14 +527,13 @@
*/
at::Tensor empty_strided(
const std::vector<int32_t>& sizes,
const std::vector<int32_t>& strides,
const std::vector<exec_aten::StridesType>& strides,
ET_UNUSED TensorShapeDynamism dynamism =
TensorShapeDynamism::DYNAMIC_UNBOUND) {
auto sizes64 = vec_32_to_64(sizes);
auto strides64 = vec_32_to_64(strides);
return at::empty_strided(
sizes64,
strides64,
strides,
DTYPE,
/*layout_opt=*/at::Layout::Strided,
/*device_opt=*/at::Device(at::DeviceType::CPU),
@@ -666,7 +733,7 @@ class TensorFactory {
torch::executor::Tensor make(
const std::vector<int32_t>& sizes,
const std::vector<ctype>& data,
const std::vector<int32_t> strides = {},
const std::vector<exec_aten::StridesType> strides = {},
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
std::vector<int32_t> default_strides;
// Generate strides from the tensor dimensions, assuming contiguous data if
@@ -746,7 +813,7 @@

/**
* Returns a new Tensor with the specified shape and data in channels last
* memory layout.
* memory format.
*
* @param[in] sizes The sizes of the dimensions of the Tensor.
* @param[in] data The data that the Tensor should be initialized with. The
@@ -764,6 +831,60 @@
sizes, data, internal::channels_last_dim_order(sizes.size()), dynamism);
}

/**
 * Given a 4-dimensional tensor in contiguous memory format, returns a new
 * Tensor with the same shape and the same data, but in channels last
 * memory format.
 *
 * @param[in] input The 4-dimensional, contiguous tensor to convert.
 * @param[in] dynamism How the shape of the returned Tensor may change.
 *
 * @return A new Tensor with the same shape and data in channels last
 * memory format.
 */
torch::executor::Tensor channels_last_like(
const torch::executor::Tensor& input,
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
const std::vector<int32_t> sizes(
input.sizes().begin(), input.sizes().end());

ET_CHECK_MSG(sizes.size() == 4, "Only 4D tensors can be channels last");
ET_CHECK_MSG(
is_contiguous_dim_order(input.dim_order().data(), input.dim()) == true,
"Input tensor is not contiguous");
int32_t N = sizes[0];
int32_t C = sizes[1];
int32_t H = sizes[2];
int32_t W = sizes[3];

std::vector<ctype> contiguous_data(
input.data_ptr<ctype>(), input.data_ptr<ctype>() + input.numel());
std::vector<ctype> channels_last_data(
N * C * H * W); // Create a new blob with the same total size to contain
// channels_last data
for (int32_t n = 0; n < N; ++n) {
for (int32_t c = 0; c < C; ++c) {
for (int32_t h = 0; h < H; ++h) {
for (int32_t w = 0; w < W; ++w) {
// Calculate the index in the original blob
int32_t old_index = ((n * C + c) * H + h) * W + w;
// Calculate the index in the new blob
int32_t new_index = ((n * H + h) * W + w) * C + c;
// Copy the data
channels_last_data[new_index] = contiguous_data[old_index];
}
}
}
}

return make_with_dimorder(
sizes,
channels_last_data,
internal::channels_last_dim_order(sizes.size()),
dynamism);
}

/**
* Returns a new Tensor with the specified shape, containing contiguous data
* with all elements set to `value`.
@@ -799,7 +920,20 @@

/**
* Returns a new Tensor with the specified shape, containing contiguous data
* with all `0` elements.
* in channels last memory format with all `0` elements.
*
* @param[in] sizes The sizes of the dimensions of the Tensor.
* @return A new Tensor with the specified shape.
*/
torch::executor::Tensor zeros_channels_last(
const std::vector<int32_t>& sizes,
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
return full_channels_last(sizes, 0, dynamism);
}

/**
* Returns a new Tensor with the specified shape, containing contiguous data
* in contiguous memory format with all `0` elements.
*
* @param[in] sizes The sizes of the dimensions of the Tensor.
* @return A new Tensor with the specified shape.
@@ -878,7 +1012,7 @@
std::vector<int32_t> sizes_;
std::vector<ctype> data_;
std::vector<uint8_t> dim_order_;
std::vector<int32_t> strides_;
std::vector<exec_aten::StridesType> strides_;
torch::executor::TensorImpl impl_;
};

@@ -449,7 +449,7 @@ TEST_F(TensorFactoryTest, MakeStridedDataIsCopied) {

// Create two tensors using the same input data and strided vector.
std::vector<int32_t> data = {1, 2, 3, 4};
std::vector<int32_t> strides = {1, 2};
std::vector<exec_aten::StridesType> strides = {1, 2};
Tensor t1 = tf.make(/*sizes=*/{2, 2}, data, strides);
Tensor t2 = tf.make(/*sizes=*/{2, 2}, data, strides);
