From 9cc3a7b60e4e6381c36017dff4ae200d5342bd6f Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 11 Sep 2024 08:12:05 -0700 Subject: [PATCH] Add helper function to create empty, full, ones and zeros tensors. (#5261) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5261 . Differential Revision: D62486240 --- extension/tensor/targets.bzl | 1 + extension/tensor/tensor_ptr.h | 27 ++- extension/tensor/tensor_ptr_maker.cpp | 114 +++++++++ extension/tensor/tensor_ptr_maker.h | 221 +++++++++++++++++- .../tensor/test/tensor_ptr_maker_test.cpp | 139 +++++++++++ extension/tensor/test/tensor_ptr_test.cpp | 16 ++ 6 files changed, 513 insertions(+), 5 deletions(-) create mode 100644 extension/tensor/tensor_ptr_maker.cpp diff --git a/extension/tensor/targets.bzl b/extension/tensor/targets.bzl index 4998b5cf15b..8493d093fa1 100644 --- a/extension/tensor/targets.bzl +++ b/extension/tensor/targets.bzl @@ -15,6 +15,7 @@ def define_common_targets(): srcs = [ "tensor_impl_ptr.cpp", "tensor_ptr.cpp", + "tensor_ptr_maker.cpp", ], exported_headers = [ "tensor.h", diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index c760de4f038..f477199a3e1 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -142,8 +142,7 @@ inline TensorPtr make_tensor_ptr( * * This template overload is specialized for cases where the tensor data is * provided as a vector. The scalar type is automatically deduced from the - * vector's data type. The deleter ensures that the data vector is properly - * managed and its lifetime is tied to the TensorImpl. + * vector's data type. * * @tparam T The C++ type of the tensor elements, deduced from the vector. * @param sizes A vector specifying the size of each dimension. @@ -174,8 +173,7 @@ TensorPtr make_tensor_ptr( * * This template overload is specialized for cases where the tensor data is * provided as a vector. The scalar type is automatically deduced from the - * vector's data type. The deleter ensures that the data vector is properly - * managed and its lifetime is tied to the TensorImpl. + * vector's data type. * * @tparam T The C++ type of the tensor elements, deduced from the vector. * @param data A vector containing the tensor's data. @@ -190,6 +188,27 @@ TensorPtr make_tensor_ptr( return make_tensor_ptr(make_tensor_impl_ptr(std::move(data), dynamism)); } +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload allows creating a Tensor from an initializer list + * of data. The scalar type is automatically deduced from the type of the + * initializer list's elements. + * + * @tparam T The C++ type of the tensor elements, deduced from the initializer + * list. + * @param data An initializer list containing the tensor's data. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template +TensorPtr make_tensor_ptr( + std::initializer_list data, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return make_tensor_ptr(std::vector(data), dynamism); +} + /** * Creates a TensorPtr that manages a Tensor with the specified properties. * diff --git a/extension/tensor/tensor_ptr_maker.cpp b/extension/tensor/tensor_ptr_maker.cpp new file mode 100644 index 00000000000..1c7b0efe589 --- /dev/null +++ b/extension/tensor/tensor_ptr_maker.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace extension { +namespace { +template < + typename INT_T, + typename std::enable_if< + std::is_integral::value && !std::is_same::value, + bool>::type = true> +bool extract_scalar(exec_aten::Scalar scalar, INT_T* out_val) { + if (!scalar.isIntegral(/*includeBool=*/false)) { + return false; + } + int64_t val = scalar.to(); + if (val < std::numeric_limits::lowest() || + val > std::numeric_limits::max()) { + return false; + } + *out_val = static_cast(val); + return true; +} + +template < + typename FLOAT_T, + typename std::enable_if::value, bool>:: + type = true> +bool extract_scalar(exec_aten::Scalar scalar, FLOAT_T* out_val) { + double val; + if (scalar.isFloatingPoint()) { + val = scalar.to(); + if (std::isfinite(val) && + (val < std::numeric_limits::lowest() || + val > std::numeric_limits::max())) { + return false; + } + } else if (scalar.isIntegral(/*includeBool=*/false)) { + val = static_cast(scalar.to()); + } else { + return false; + } + *out_val = static_cast(val); + return true; +} + +template < + typename BOOL_T, + typename std::enable_if::value, bool>::type = + true> +bool extract_scalar(exec_aten::Scalar scalar, BOOL_T* out_val) { + if (scalar.isIntegral(false)) { + *out_val = static_cast(scalar.to()); + return true; + } + if (scalar.isBoolean()) { + *out_val = scalar.to(); + return true; + } + return false; +} + +#define ET_EXTRACT_SCALAR(scalar, out_val) \ + ET_CHECK_MSG( \ + extract_scalar(scalar, &out_val), \ + #scalar " could not be extracted: wrong type or out of range"); + +} // namespace + +TensorPtr empty_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism) { + std::vector data( + exec_aten::compute_numel(sizes.data(), sizes.size()) * + exec_aten::elementSize(type)); + return make_tensor_ptr( + type, + std::move(sizes), + std::move(data), + {}, + std::move(strides), + dynamism); +} + +TensorPtr full_strided( + std::vector sizes, + std::vector strides, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism) { + auto tensor = + empty_strided(std::move(sizes), std::move(strides), type, dynamism); + ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] { + CTYPE value; + ET_EXTRACT_SCALAR(fill_value, value); + std::fill( + tensor->mutable_data_ptr(), + tensor->mutable_data_ptr() + tensor->numel(), + value); + }); + return tensor; +} + +} // namespace extension +} // namespace executorch diff --git a/extension/tensor/tensor_ptr_maker.h b/extension/tensor/tensor_ptr_maker.h index fd97e53dbca..132bd1f12c6 100644 --- a/extension/tensor/tensor_ptr_maker.h +++ b/extension/tensor/tensor_ptr_maker.h @@ -15,7 +15,7 @@ namespace extension { /** * A helper class for creating TensorPtr instances from raw data and tensor - * properties. Note the the TensorPtr created by this class will not own the + * properties. Note that the TensorPtr created by this class will not own the * data, so it must outlive the TensorPtr. * * TensorPtrMaker provides a fluent interface for specifying various properties @@ -31,6 +31,7 @@ class TensorPtrMaker final { // But it is movable. TensorPtrMaker(TensorPtrMaker&&) = default; TensorPtrMaker& operator=(TensorPtrMaker&&) = default; + /** * Sets the scalar type of the tensor elements. * @@ -278,5 +279,223 @@ inline TensorPtr from_blob( .make_tensor_ptr(); } +/** + * Creates a TensorPtr with the specified sizes, strides, and properties. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. The tensor is created with the + * specified strides. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr empty_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates an empty TensorPtr with the same size and properties as the given + * tensor. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr empty_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == exec_aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return empty_strided( + {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + type, + dynamism); +} + +/** + * Creates an empty TensorPtr with the specified sizes and properties. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr empty( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return empty_strided(std::move(sizes), {}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with the specified value. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr full_strided( + std::vector sizes, + std::vector strides, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with the specified value, with the same size and + * properties as another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr full_like( + const TensorPtr& other, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == exec_aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return full_strided( + {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + fill_value, + type, + dynamism); +} + +/** + * Creates a TensorPtr filled with the specified value. + * + * @param sizes A vector specifying the size of each dimension. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr full( + std::vector sizes, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_strided(std::move(sizes), {}, fill_value, type, dynamism); +} + +/** + * Creates a TensorPtr that holds a scalar value. + * + * @param value The scalar value to create the tensor with. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created scalar Tensor. + */ +inline TensorPtr scalar_tensor( + exec_aten::Scalar value, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full({}, value, type, dynamism); +} + +/** + * Creates a TensorPtr filled with ones, with the same size and properties as + * another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the `other` tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr ones_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_like(other, 1, type, dynamism); +} + +/** + * Creates a TensorPtr filled with ones. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr ones( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full(std::move(sizes), 1, type, dynamism); +} + +/** + * Creates a TensorPtr filled with zeros, with the same size and properties as + * another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the `other` tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr zeros_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_like(other, 0, type, dynamism); +} + +/** + * Creates a TensorPtr filled with zeros. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr zeros( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full(std::move(sizes), 0, type, dynamism); +} + } // namespace extension } // namespace executorch diff --git a/extension/tensor/test/tensor_ptr_maker_test.cpp b/extension/tensor/test/tensor_ptr_maker_test.cpp index d1b4179a260..7530a3709ab 100644 --- a/extension/tensor/test/tensor_ptr_maker_test.cpp +++ b/extension/tensor/test/tensor_ptr_maker_test.cpp @@ -178,3 +178,142 @@ TEST_F(TensorPtrMakerTest, TensorDeleterReleasesCapturedSharedPtr) { EXPECT_TRUE(deleter_called); EXPECT_EQ(data_ptr.use_count(), 1); } + +TEST_F(TensorPtrMakerTest, CreateEmpty) { + auto tensor = empty({4, 5}); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + + auto tensor2 = empty({4, 5}, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + + auto tensor3 = empty({4, 5}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + + auto tensor4 = empty({4, 5}, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); +} + +TEST_F(TensorPtrMakerTest, CreateFull) { + auto tensor = full({4, 5}, 7); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 7); + + auto tensor2 = full({4, 5}, 3, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 3); + + auto tensor3 = full({4, 5}, 9, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->const_data_ptr()[0], 9); + + auto tensor4 = full({4, 5}, 11, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->const_data_ptr()[0], 11); +} + +TEST_F(TensorPtrMakerTest, CreateScalar) { + auto tensor = scalar_tensor(3.14f); + + EXPECT_EQ(tensor->dim(), 0); + EXPECT_EQ(tensor->numel(), 1); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 3.14f); + + auto tensor2 = scalar_tensor(5, exec_aten::ScalarType::Int); + + EXPECT_EQ(tensor2->dim(), 0); + EXPECT_EQ(tensor2->numel(), 1); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 5); + + auto tensor3 = scalar_tensor(7.0, exec_aten::ScalarType::Double); + + EXPECT_EQ(tensor3->dim(), 0); + EXPECT_EQ(tensor3->numel(), 1); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor3->const_data_ptr()[0], 7.0); +} + +TEST_F(TensorPtrMakerTest, CreateOnes) { + auto tensor = ones({4, 5}); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 1); + + auto tensor2 = ones({4, 5}, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 1); + + auto tensor3 = ones({4, 5}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->const_data_ptr()[0], 1); + + auto tensor4 = ones({4, 5}, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->const_data_ptr()[0], 1); +} + +TEST_F(TensorPtrMakerTest, CreateZeros) { + auto tensor = zeros({4, 5}); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 0); + + auto tensor2 = zeros({4, 5}, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 0); + + auto tensor3 = zeros({4, 5}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->const_data_ptr()[0], 0); + + auto tensor4 = zeros({4, 5}, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->const_data_ptr()[0], 0); +} diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index d5582630494..653e2ef98d7 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -197,6 +197,18 @@ TEST_F(TensorPtrTest, TensorOwningEmptyData) { EXPECT_EQ(tensor->strides()[0], 5); EXPECT_EQ(tensor->strides()[1], 1); EXPECT_EQ(tensor->data_ptr(), nullptr); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); +} + +TEST_F(TensorPtrTest, TensorImplDataOnly) { + auto tensor = make_tensor_ptr({1.0f, 2.0f, 3.0f, 4.0f}); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 1.0); + EXPECT_EQ(tensor->const_data_ptr()[3], 4.0); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); } TEST_F(TensorPtrTest, TensorImplDataOnlyDoubleType) { @@ -208,6 +220,7 @@ TEST_F(TensorPtrTest, TensorImplDataOnlyDoubleType) { EXPECT_EQ(tensor->strides()[0], 1); EXPECT_EQ(tensor->const_data_ptr()[0], 1.0); EXPECT_EQ(tensor->const_data_ptr()[3], 4.0); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Double); } TEST_F(TensorPtrTest, TensorImplDataOnlyInt32Type) { @@ -219,6 +232,7 @@ TEST_F(TensorPtrTest, TensorImplDataOnlyInt32Type) { EXPECT_EQ(tensor->strides()[0], 1); EXPECT_EQ(tensor->const_data_ptr()[0], 10); EXPECT_EQ(tensor->const_data_ptr()[3], 40); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Int); } TEST_F(TensorPtrTest, TensorImplDataOnlyInt64Type) { @@ -230,6 +244,7 @@ TEST_F(TensorPtrTest, TensorImplDataOnlyInt64Type) { EXPECT_EQ(tensor->strides()[0], 1); EXPECT_EQ(tensor->const_data_ptr()[0], 100); EXPECT_EQ(tensor->const_data_ptr()[3], 400); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Long); } TEST_F(TensorPtrTest, TensorImplDataOnlyUint8Type) { @@ -241,6 +256,7 @@ TEST_F(TensorPtrTest, TensorImplDataOnlyUint8Type) { EXPECT_EQ(tensor->strides()[0], 1); EXPECT_EQ(tensor->const_data_ptr()[0], 10); EXPECT_EQ(tensor->const_data_ptr()[3], 40); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Byte); } TEST_F(TensorPtrTest, TensorImplAmbiguityWithMixedVectors) {