diff --git a/extension/runner_util/managed_tensor.h b/extension/runner_util/managed_tensor.h index 5e2fb62c6f7..16a84a13df4 100644 --- a/extension/runner_util/managed_tensor.h +++ b/extension/runner_util/managed_tensor.h @@ -48,13 +48,23 @@ class ManagedTensor { #ifdef USE_ATEN_LIB tensor_ = torch::from_blob(data, sizes, dtype); #else + // Calculate strides. + strides_ = std::vector(sizes_.size()); + if (sizes_.size() > 0) { + strides_.back() = 1; + for (size_t i = strides_.size() - 1; i > 0; --i) { + strides_[i - 1] = strides_[i] * sizes_[i]; + } + } + + // Allocate TensorImpl. tensor_impl_ = std::make_unique( dtype, sizes_.size(), sizes_.data(), data, - nullptr, - nullptr, + /*dim_order=*/nullptr, + strides_.data(), TensorShapeDynamism::DYNAMIC_BOUND); #endif } @@ -80,6 +90,7 @@ class ManagedTensor { private: std::unique_ptr tensor_impl_; std::vector sizes_; + std::vector strides_; #ifdef USE_ATEN_LIB Tensor tensor_; #endif diff --git a/extension/runner_util/test/managed_tensor_test.cpp b/extension/runner_util/test/managed_tensor_test.cpp index d5234570f43..b511cdbcf17 100644 --- a/extension/runner_util/test/managed_tensor_test.cpp +++ b/extension/runner_util/test/managed_tensor_test.cpp @@ -25,8 +25,9 @@ class ManagedTensorTest : public ::testing::Test { void SetUp() override { torch::executor::runtime_init(); - data_ = {1, 2, 3, 4, 5, 6}; - sizes_ = {2, 3}; + data_ = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + sizes_ = {2, 3, 4}; + expected_strides_ = {12, 4, 1}; managed_tensor_ = std::make_unique(data_.data(), sizes_, ScalarType::Long); } @@ -34,6 +35,7 @@ class ManagedTensorTest : public ::testing::Test { protected: std::vector data_; std::vector sizes_; + std::vector expected_strides_; std::unique_ptr managed_tensor_; }; @@ -43,6 +45,9 @@ TEST_F(ManagedTensorTest, Smoke) { EXPECT_EQ(tensor.sizes(), ArrayRef(sizes_.data(), sizes_.size())); EXPECT_EQ(tensor.scalar_type(), ScalarType::Long); EXPECT_EQ(tensor.const_data_ptr(), data_.data()); + for (size_t i = 0; i < expected_strides_.size(); ++i) { + EXPECT_EQ(tensor.strides()[i], expected_strides_[i]); + } } TEST_F(ManagedTensorTest, ResizeWithUpdatedRank) { @@ -50,17 +55,17 @@ TEST_F(ManagedTensorTest, ResizeWithUpdatedRank) { // https://github.com/google/googletest/issues/2834 #if !GTEST_OS_IOS EXPECT_EXIT( - managed_tensor_->resize(std::vector{2, 3, 4}), + managed_tensor_->resize(std::vector{2, 3, 4, 5}), ::testing::KilledBySignal(SIGABRT), ""); #endif } TEST_F(ManagedTensorTest, ResizeShrink) { - managed_tensor_->resize(std::vector{2, 2}); + managed_tensor_->resize(std::vector{2, 2, 2}); const auto tensor = managed_tensor_->get_aliasing_tensor(); - std::vector expected_sizes = {2, 2}; + std::vector expected_sizes = {2, 2, 2}; EXPECT_EQ( tensor.sizes(), ArrayRef(expected_sizes.data(), expected_sizes.size())); @@ -69,10 +74,10 @@ TEST_F(ManagedTensorTest, ResizeShrink) { } TEST_F(ManagedTensorTest, Resize) { - managed_tensor_->resize(std::vector{3, 2}); + managed_tensor_->resize(std::vector{4, 3, 2}); const auto tensor = managed_tensor_->get_aliasing_tensor(); - std::vector expected_sizes = {3, 2}; + std::vector expected_sizes = {4, 3, 2}; EXPECT_EQ( tensor.sizes(), ArrayRef(expected_sizes.data(), expected_sizes.size()));