diff --git a/examples/llm_manual/managed_tensor.h b/examples/llm_manual/managed_tensor.h index d401ae4d18b..d870f4861e6 100644 --- a/examples/llm_manual/managed_tensor.h +++ b/examples/llm_manual/managed_tensor.h @@ -30,28 +30,21 @@ class ManagedTensor { using DimOrderType = exec_aten::DimOrderType; /// The type used for elements of `strides()`. using StridesType = exec_aten::StridesType; + ManagedTensor() = delete; explicit ManagedTensor( void* data, const std::vector& sizes, ScalarType dtype) - : dtype_(dtype), sizes_(sizes), data_ptr_(data) { - ssize_t dim = sizes.size(); - dim_order_.resize(dim); - strides_.resize(dim); - for (size_t i = 0; i < dim; ++i) { - dim_order_[i] = i; - } - dim_order_to_stride_nocheck( - sizes.data(), dim_order_.data(), dim, strides_.data()); + : sizes_(sizes) { tensor_impl_ = std::make_unique( - dtype_, - dim, + dtype, + sizes_.size(), sizes_.data(), - data_ptr_, - dim_order_.data(), - strides_.data(), + data, + nullptr, + nullptr, TensorShapeDynamism::DYNAMIC_BOUND); } @@ -63,12 +56,9 @@ class ManagedTensor { } private: - void* data_ptr_ = nullptr; std::unique_ptr tensor_impl_; std::vector sizes_; - std::vector strides_; - std::vector dim_order_; - ScalarType dtype_; }; + } // namespace executor } // namespace torch diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index bf7ba8b6ce3..4cd4025dcb9 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -20,8 +20,8 @@ #endif #include #include -#include #include +#include #include namespace torch { @@ -107,25 +107,39 @@ struct type_convert< typename remove_const_ref::type, torch::executor::Tensor>>> final { - public: - ATensor val; - std::unique_ptr managed_tensor; - torch::executor::Tensor converted; - std::vector sizes; - explicit type_convert(ATensor value) - : val(value), converted(torch::executor::Tensor(nullptr)) { - for (auto size : val.sizes()) { - sizes.push_back(size); - } - torch::executor::ScalarType scalar_type = - static_cast(val.scalar_type()); - managed_tensor = std::make_unique( - val.mutable_data_ptr(), sizes, scalar_type); - converted = managed_tensor->get_aliasing_tensor(); + explicit type_convert(ATensor value) : value_(value) { + auto sizes = std::make_shared>( + value_.sizes().begin(), value_.sizes().end()); + const ssize_t dim = sizes->size(); + auto dim_order = std::make_shared>(dim); + auto strides = std::make_shared>(dim); + + std::iota(dim_order->begin(), dim_order->end(), 0); + dim_order_to_stride_nocheck( + sizes->data(), dim_order->data(), dim, strides->data()); + + auto tensor_impl = std::make_shared( + static_cast(value_.scalar_type()), + sizes->size(), + sizes->data(), + value_.mutable_data_ptr(), + dim_order->data(), + strides->data()); + + converted_ = std::unique_ptr>( + new Tensor(tensor_impl.get()), + [sizes, dim_order, strides, tensor_impl](Tensor* pointer) { + delete pointer; + }); } + ETensor call() { - return converted; + return *converted_; } + + private: + ATensor value_; + std::unique_ptr> converted_; }; // Tensors: ETen to ATen. @@ -139,21 +153,22 @@ struct type_convert< typename remove_const_ref::type, torch::executor::Tensor>>> final { - public: - ETensor val; - at::Tensor converted; - std::vector sizes; - explicit type_convert(ETensor value) : val(value) { - for (auto size : val.sizes()) { - sizes.push_back(size); - } - c10::ScalarType scalar_type = - static_cast(val.scalar_type()); - converted = at::from_blob(val.mutable_data_ptr(), sizes, scalar_type); + explicit type_convert(ETensor value) + : value_(value), sizes_(value_.sizes().begin(), value_.sizes().end()) { + converted_ = at::from_blob( + value_.mutable_data_ptr(), + sizes_, + static_cast(value_.scalar_type())); } + ATensor call() { - return converted; + return converted_; } + + private: + ETensor value_; + at::Tensor converted_; + std::vector sizes_; }; // Optionals: ATen to ETen. diff --git a/extension/aten_util/targets.bzl b/extension/aten_util/targets.bzl index 6e325830292..b396cb78325 100644 --- a/extension/aten_util/targets.bzl +++ b/extension/aten_util/targets.bzl @@ -27,7 +27,6 @@ def define_common_targets(): ], exported_deps = [ "//executorch/extension/kernel_util:kernel_util", - "//executorch/extension/runner_util:managed_tensor", "//executorch/runtime/core:core", "//executorch/runtime/core:evalue", "//executorch/runtime/core/exec_aten:lib", diff --git a/extension/runner_util/managed_tensor.h b/extension/runner_util/managed_tensor.h index d92f8d19bef..5e2fb62c6f7 100644 --- a/extension/runner_util/managed_tensor.h +++ b/extension/runner_util/managed_tensor.h @@ -37,39 +37,29 @@ class ManagedTensor { using DimOrderType = exec_aten::DimOrderType; /// The type used for elements of `strides()`. using StridesType = exec_aten::StridesType; + ManagedTensor() = delete; explicit ManagedTensor( void* data, const std::vector& sizes, ScalarType dtype) - : dtype_(dtype), sizes_(sizes), data_ptr_(data) { + : sizes_(sizes) { #ifdef USE_ATEN_LIB - tensor_ = torch::from_blob(data, sizes, dtype_); + tensor_ = torch::from_blob(data, sizes, dtype); #else - ssize_t dim = sizes.size(); - dim_order_.resize(dim); - strides_.resize(dim); - for (size_t i = 0; i < dim; ++i) { - dim_order_[i] = i; - } - dim_order_to_stride_nocheck( - sizes.data(), dim_order_.data(), dim, strides_.data()); tensor_impl_ = std::make_unique( - dtype_, - dim, + dtype, + sizes_.size(), sizes_.data(), - data_ptr_, - dim_order_.data(), - strides_.data(), + data, + nullptr, + nullptr, TensorShapeDynamism::DYNAMIC_BOUND); #endif } void resize(const std::vector& new_sizes) { - ET_CHECK_MSG( - new_sizes.size() == sizes_.size(), - "Cannot change rank of a managed tensor"); auto err = resize_tensor( this->get_aliasing_tensor(), exec_aten::ArrayRef(new_sizes.data(), new_sizes.size())); @@ -88,15 +78,12 @@ class ManagedTensor { } private: - ScalarType dtype_; std::unique_ptr tensor_impl_; std::vector sizes_; - std::vector strides_; - std::vector dim_order_; - void* data_ptr_ = nullptr; #ifdef USE_ATEN_LIB Tensor tensor_; #endif }; + } // namespace executor } // namespace torch diff --git a/extension/runner_util/test/managed_tensor_test.cpp b/extension/runner_util/test/managed_tensor_test.cpp index 9c14553ed88..d5234570f43 100644 --- a/extension/runner_util/test/managed_tensor_test.cpp +++ b/extension/runner_util/test/managed_tensor_test.cpp @@ -42,15 +42,6 @@ TEST_F(ManagedTensorTest, Smoke) { EXPECT_EQ(tensor.sizes(), ArrayRef(sizes_.data(), sizes_.size())); EXPECT_EQ(tensor.scalar_type(), ScalarType::Long); - std::vector expected_dim_order = {0, 1}; - EXPECT_EQ( - tensor.dim_order(), - ArrayRef( - expected_dim_order.data(), expected_dim_order.size())); - std::vector expected_strides = {3, 1}; - EXPECT_EQ( - tensor.strides(), - ArrayRef(expected_strides.data(), expected_strides.size())); EXPECT_EQ(tensor.const_data_ptr(), data_.data()); } @@ -74,15 +65,6 @@ TEST_F(ManagedTensorTest, ResizeShrink) { tensor.sizes(), ArrayRef(expected_sizes.data(), expected_sizes.size())); EXPECT_EQ(tensor.scalar_type(), ScalarType::Long); - std::vector expected_dim_order = {0, 1}; - EXPECT_EQ( - tensor.dim_order(), - ArrayRef( - expected_dim_order.data(), expected_dim_order.size())); - std::vector expected_strides = {2, 1}; - EXPECT_EQ( - tensor.strides(), - ArrayRef(expected_strides.data(), expected_strides.size())); EXPECT_EQ(tensor.const_data_ptr(), data_.data()); } @@ -95,14 +77,5 @@ TEST_F(ManagedTensorTest, Resize) { tensor.sizes(), ArrayRef(expected_sizes.data(), expected_sizes.size())); EXPECT_EQ(tensor.scalar_type(), ScalarType::Long); - std::vector expected_dim_order = {0, 1}; - EXPECT_EQ( - tensor.dim_order(), - ArrayRef( - expected_dim_order.data(), expected_dim_order.size())); - std::vector expected_strides = {2, 1}; - EXPECT_EQ( - tensor.strides(), - ArrayRef(expected_strides.data(), expected_strides.size())); EXPECT_EQ(tensor.const_data_ptr(), data_.data()); }