diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 1ec73882573..a6b377b2ef1 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -15,28 +15,6 @@ namespace executorch { namespace extension { #ifndef USE_ATEN_LIB -namespace internal { -/** - * Custom deleter for TensorPtr that ensures proper management of the associated - * TensorImplPtr. - * - * Since Tensor does not own its TensorImpl, this deleter manages the - * TensorImplPtr lifecycle, ensuring dynamic metadata (sizes, dim_order, - * strides) is released appropriately when the Tensor is destroyed. - */ -struct TensorPtrDeleter final { - TensorImplPtr tensor_impl; - - void operator()(exec_aten::Tensor* pointer) { - // Release all resources immediately since the data held by the - // TensorPtrDeleter is tied to the managed object, not the smart pointer - // itself. We need to free this memory when the object is destroyed, not - // when the smart pointer (and deleter) are eventually destroyed or reset. - tensor_impl.reset(); - delete pointer; - } -}; -} // namespace internal /** * A smart pointer for managing the lifecycle of a Tensor. @@ -47,8 +25,62 @@ struct TensorPtrDeleter final { * by design. It ensures that the underlying TensorImpl can be safely shared * among tensors as needed. */ -using TensorPtr = - std::unique_ptr; +class TensorPtr : private std::unique_ptr { + public: + using unique_ptr::element_type; + using unique_ptr::get; + using unique_ptr::pointer; + using unique_ptr::operator bool; + using unique_ptr::operator*; + using unique_ptr::operator->; + + constexpr TensorPtr() = default; + constexpr TensorPtr(std::nullptr_t) {} + ~TensorPtr() = default; + TensorPtr(TensorPtr&& rhs) noexcept = default; + TensorPtr& operator=(TensorPtr&& rhs) noexcept = default; + + TensorPtr(TensorImplPtr p) + : unique_ptr(std::make_unique(p.get())), + tensor_impl(std::move(p)) {} + + // release() does not make sense as it separates the Tensor from the + // TensorImplPtr. using unique_ptr::release; + + void reset() { + unique_ptr::reset(); + tensor_impl.reset(); + } + + void swap(TensorPtr& other) { + unique_ptr::swap(static_cast(other)); + tensor_impl.swap(other.tensor_impl); + } + + bool operator==(const TensorPtr& rhs) const { + // No need to check tensor_impl; if the Tensor pointers are equal + // then the tensor_impls must also be since this is a unique + // pointer. + return static_cast(*this) == + static_cast(rhs); + } + + bool operator!=(const TensorPtr& rhs) const { + return !(*this == rhs); + } + + bool operator==(std::nullptr_t) const { + return static_cast(*this) == nullptr; + } + + bool operator!=(std::nullptr_t) const { + return !(*this == nullptr); + } + + private: + friend TensorPtr make_tensor_ptr(const TensorPtr& tensor); + TensorImplPtr tensor_impl; +}; #else /** * A smart pointer type for managing the lifecycle of a Tensor. @@ -74,9 +106,7 @@ using TensorPtr = std::unique_ptr; */ inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) { #ifndef USE_ATEN_LIB - auto tensor = std::make_unique(tensor_impl.get()); - return TensorPtr( - tensor.release(), internal::TensorPtrDeleter{std::move(tensor_impl)}); + return TensorPtr(std::move(tensor_impl)); #else return std::make_unique(std::move(tensor_impl)); #endif // USE_ATEN_LIB @@ -96,7 +126,7 @@ inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) { */ inline TensorPtr make_tensor_ptr(const TensorPtr& tensor) { #ifndef USE_ATEN_LIB - return make_tensor_ptr(tensor.get_deleter().tensor_impl); + return make_tensor_ptr(tensor.tensor_impl); #else return make_tensor_ptr(tensor->getIntrusivePtr()); #endif // USE_ATEN_LIB