diff --git a/extension/tensor/tensor_impl_ptr.cpp b/extension/tensor/tensor_impl_ptr.cpp index 358acfd1850..c1afc4ab39b 100644 --- a/extension/tensor/tensor_impl_ptr.cpp +++ b/extension/tensor/tensor_impl_ptr.cpp @@ -15,43 +15,6 @@ namespace executorch { namespace extension { -namespace { -#ifndef USE_ATEN_LIB -// No-op deleter that does nothing when called. -static void noop_deleter(void*) {} - -/** - * Custom deleter for TensorImplPtr that ensures the memory associated with - * dynamic metadata (sizes, dim_order, and strides) is properly managed when the - * TensorImpl is destroyed. - * - * Since TensorImpl does not own the metadata arrays (sizes, dim_order, - * strides), this deleter is responsible for releasing that memory when the - * TensorImpl is destroyed. - */ -struct TensorImplPtrDeleter final { - // A custom deleter of the std::shared_ptr is required to be copyable until - // C++20, so any data it holds must be copyable too. Hence, we use shared_ptr - // to hold the data and metadata to avoid unnecessary copies. - std::shared_ptr<void> data; - std::shared_ptr<std::vector<exec_aten::SizesType>> sizes; - std::shared_ptr<std::vector<exec_aten::DimOrderType>> dim_order; - std::shared_ptr<std::vector<exec_aten::StridesType>> strides; - - void operator()(exec_aten::TensorImpl* pointer) { - // Release all resources immediately since the data held by the - // TensorImplPtrDeleter is tied to the managed object, not the smart pointer - // itself. We need to free this memory when the object is destroyed, not - // when the smart pointer (and deleter) are eventually destroyed or reset. - data.reset(); - sizes.reset(); - dim_order.reset(); - strides.reset(); - delete pointer; - } -}; -#endif // USE_ATEN_LIB -} // namespace TensorImplPtr make_tensor_impl_ptr( std::vector<exec_aten::SizesType> sizes, @@ -89,7 +52,7 @@ TensorImplPtr make_tensor_impl_ptr( strides = std::move(computed_strides); } #ifndef USE_ATEN_LIB - auto tensor_impl = std::make_unique<exec_aten::TensorImpl>( + exec_aten::TensorImpl tensor_impl( type, dim, sizes.data(), @@ -98,15 +61,11 @@ TensorImplPtr make_tensor_impl_ptr( strides.data(), dim > 0 ? 
dynamism : exec_aten::TensorShapeDynamism::STATIC); return TensorImplPtr( - tensor_impl.release(), - TensorImplPtrDeleter{ - std::shared_ptr<void>( - data, deleter ? std::move(deleter) : noop_deleter), - std::make_shared<std::vector<exec_aten::SizesType>>(std::move(sizes)), - std::make_shared<std::vector<exec_aten::DimOrderType>>( - std::move(dim_order)), - std::make_shared<std::vector<exec_aten::StridesType>>( - std::move(strides))}); + std::move(tensor_impl), + std::move(sizes), + std::move(dim_order), + std::move(strides), + std::move(deleter)); #else auto options = c10::TensorOptions() .dtype(c10::scalarTypeToTypeMeta(type)) diff --git a/extension/tensor/tensor_impl_ptr.h b/extension/tensor/tensor_impl_ptr.h index 89fc7ff1ebf..2c3a0d2b833 100644 --- a/extension/tensor/tensor_impl_ptr.h +++ b/extension/tensor/tensor_impl_ptr.h @@ -30,7 +30,99 @@ namespace extension { * It serves as a safer, more convenient alternative to the original TensorImpl, * which does not manage its metadata by design. */ -using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>; +class TensorImplPtr { + public: + constexpr TensorImplPtr() = default; + explicit constexpr TensorImplPtr(std::nullptr_t) {} + TensorImplPtr( + exec_aten::TensorImpl tensor_impl, + std::vector<exec_aten::SizesType> sizes, + std::vector<exec_aten::DimOrderType> dim_order, + std::vector<exec_aten::StridesType> strides, + std::function<void(void*)> data_deleter = nullptr) + : repr_(std::make_shared<HeapData>( + std::move(tensor_impl), + std::move(sizes), + std::move(dim_order), + std::move(strides), + std::move(data_deleter))) {} + + operator bool() const { + return static_cast<bool>(repr_); + } + + exec_aten::TensorImpl* get() const { + return repr_ ? 
&repr_->tensor_impl_ : nullptr; + } + + exec_aten::TensorImpl* operator->() const { + return get(); + } + + exec_aten::TensorImpl& operator*() const { + ET_DCHECK(repr_ != nullptr); + return *get(); + } + + void reset() { + repr_.reset(); + } + + void swap(TensorImplPtr& other) noexcept { + repr_.swap(other.repr_); + } + + bool operator==(const TensorImplPtr& rhs) const { + return repr_ == rhs.repr_; + } + + bool operator!=(const TensorImplPtr& rhs) const { + return !(*this == rhs); + } + + bool operator==(std::nullptr_t) const { + return !operator bool(); + } + + bool operator!=(std::nullptr_t) const { + return !(*this == nullptr); + } + + auto use_count() const noexcept { + return repr_.use_count(); + } + + private: + struct HeapData { + exec_aten::TensorImpl tensor_impl_; + // TODO: consolidate these allocations similar to torch::Tensor's + // SizesAndStrides? + std::vector<exec_aten::SizesType> sizes_; + std::vector<exec_aten::DimOrderType> dim_order_; + std::vector<exec_aten::StridesType> strides_; + // TODO: don't pay for the deleter if it wasn't set. + std::function<void(void*)> data_deleter_; + + HeapData( + exec_aten::TensorImpl&& ti, + std::vector<exec_aten::SizesType>&& sizes, + std::vector<exec_aten::DimOrderType>&& dim_order, + std::vector<exec_aten::StridesType>&& strides, + std::function<void(void*)>&& data_deleter) + : tensor_impl_(std::move(ti)), + sizes_(std::move(sizes)), + dim_order_(std::move(dim_order)), + strides_(std::move(strides)), + data_deleter_(std::move(data_deleter)) {} + + ~HeapData() { + if (data_deleter_) { + data_deleter_(tensor_impl_.mutable_data()); + } + } + }; + std::shared_ptr<HeapData> repr_; +}; #else /** * A smart pointer type for managing the lifecycle of a TensorImpl. 
diff --git a/extension/tensor/test/tensor_impl_ptr_test.cpp b/extension/tensor/test/tensor_impl_ptr_test.cpp index d3d827a4955..94f8d931519 100644 --- a/extension/tensor/test/tensor_impl_ptr_test.cpp +++ b/extension/tensor/test/tensor_impl_ptr_test.cpp @@ -23,6 +23,31 @@ class TensorImplPtrTest : public ::testing::Test { } }; +TEST_F(TensorImplPtrTest, BasicSmartPointerAccess) { + TensorImplPtr p; + EXPECT_FALSE(p); + EXPECT_EQ(p, nullptr); + TensorImplPtr p2 = make_tensor_impl_ptr({1}, nullptr); + EXPECT_TRUE(p2); + EXPECT_NE(p2, nullptr); + EXPECT_EQ(p2->dim(), 1); + EXPECT_EQ((*p2).dim(), 1); + EXPECT_NE(p, p2); + p2.reset(); + EXPECT_FALSE(p2); + EXPECT_EQ(p2, nullptr); + EXPECT_EQ(p, p2); +} + +TEST_F(TensorImplPtrTest, Swap) { + TensorImplPtr p; + TensorImplPtr p2 = make_tensor_impl_ptr({1}, nullptr); + p.swap(p2); + EXPECT_FALSE(p2); + EXPECT_TRUE(p); + EXPECT_EQ(p->dim(), 1); +} + TEST_F(TensorImplPtrTest, ScalarTensorCreation) { float scalar_data = 3.14f; auto tensor_impl = make_tensor_impl_ptr({}, &scalar_data);