diff --git a/extension/tensor/tensor_impl_ptr.cpp b/extension/tensor/tensor_impl_ptr.cpp index 358acfd1850..e3b065162ba 100644 --- a/extension/tensor/tensor_impl_ptr.cpp +++ b/extension/tensor/tensor_impl_ptr.cpp @@ -17,37 +17,37 @@ namespace executorch { namespace extension { namespace { #ifndef USE_ATEN_LIB -// No-op deleter that does nothing when called. -static void noop_deleter(void*) {} - /** - * Custom deleter for TensorImplPtr that ensures the memory associated with - * dynamic metadata (sizes, dim_order, and strides) is properly managed when the - * TensorImpl is destroyed. - * - * Since TensorImpl does not own the metadata arrays (sizes, dim_order, - * strides), this deleter is responsible for releasing that memory when the - * TensorImpl is destroyed. + * A structure that consolidates the metadata (sizes, dim_order, strides) and + * the data buffer associated with a TensorImpl. Since TensorImpl does not own + * the memory for these metadata arrays or the data itself, this structure + * ensures that they are managed together and have the same lifetime as the + * TensorImpl. When the TensorImpl is destroyed, the Storage structure ensures + * proper cleanup of the associated metadata and data if needed. */ -struct TensorImplPtrDeleter final { - // A custom deleter of the std::shared_ptr is required to be copyable until - // C++20, so any data it holds must be copyable too. Hence, we use shared_ptr - // to hold the data and metadata to avoid unnecessary copies. - std::shared_ptr data; - std::shared_ptr> sizes; - std::shared_ptr> dim_order; - std::shared_ptr> strides; +struct Storage final { + exec_aten::TensorImpl tensor_impl; + std::vector sizes; + std::vector dim_order; + std::vector strides; + std::function deleter; - void operator()(exec_aten::TensorImpl* pointer) { - // Release all resources immediately since the data held by the - // TensorImplPtrDeleter is tied to the managed object, not the smart pointer - // itself. We need to free this memory when the object is destroyed, not - // when the smart pointer (and deleter) are eventually destroyed or reset. - data.reset(); - sizes.reset(); - dim_order.reset(); - strides.reset(); - delete pointer; + Storage( + exec_aten::TensorImpl&& tensor_impl, + std::vector&& sizes, + std::vector&& dim_order, + std::vector&& strides, + std::function&& deleter) + : tensor_impl(std::move(tensor_impl)), + sizes(std::move(sizes)), + dim_order(std::move(dim_order)), + strides(std::move(strides)), + deleter(std::move(deleter)) {} + + ~Storage() { + if (deleter) { + deleter(tensor_impl.mutable_data()); + } } }; #endif // USE_ATEN_LIB @@ -89,7 +89,7 @@ TensorImplPtr make_tensor_impl_ptr( strides = std::move(computed_strides); } #ifndef USE_ATEN_LIB - auto tensor_impl = std::make_unique( + exec_aten::TensorImpl tensor_impl( type, dim, sizes.data(), @@ -97,16 +97,15 @@ TensorImplPtr make_tensor_impl_ptr( dim_order.data(), strides.data(), dim > 0 ? dynamism : exec_aten::TensorShapeDynamism::STATIC); - return TensorImplPtr( - tensor_impl.release(), - TensorImplPtrDeleter{ - std::shared_ptr( - data, deleter ? std::move(deleter) : noop_deleter), - std::make_shared>(std::move(sizes)), - std::make_shared>( - std::move(dim_order)), - std::make_shared>( - std::move(strides))}); + auto storage = std::make_shared( + std::move(tensor_impl), + std::move(sizes), + std::move(dim_order), + std::move(strides), + std::move(deleter)); + const auto tensor_impl_ptr = &storage->tensor_impl; + return std::shared_ptr( + std::move(storage), tensor_impl_ptr); #else auto options = c10::TensorOptions() .dtype(c10::scalarTypeToTypeMeta(type)) @@ -139,16 +138,16 @@ TensorImplPtr make_tensor_impl_ptr( data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) * exec_aten::elementSize(type), "Data size is smaller than required by sizes and scalar type."); - auto raw_data_ptr = data.data(); - auto data_ptr = std::make_shared>(std::move(data)); + auto data_ptr = data.data(); return make_tensor_impl_ptr( std::move(sizes), - raw_data_ptr, + data_ptr, std::move(dim_order), std::move(strides), type, dynamism, - [data_ptr = std::move(data_ptr)](void*) {}); + // Data is moved into the deleter and is destroyed together with Storage. + [data = std::move(data)](void*) {}); } } // namespace extension