Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 58 additions & 28 deletions extension/tensor/tensor_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,6 @@ namespace executorch {
namespace extension {

#ifndef USE_ATEN_LIB
namespace internal {
/**
* Custom deleter for TensorPtr that ensures proper management of the associated
* TensorImplPtr.
*
* Since Tensor does not own its TensorImpl, this deleter manages the
* TensorImplPtr lifecycle, ensuring dynamic metadata (sizes, dim_order,
* strides) is released appropriately when the Tensor is destroyed.
*/
struct TensorPtrDeleter final {
TensorImplPtr tensor_impl;

void operator()(exec_aten::Tensor* pointer) {
// Release all resources immediately since the data held by the
// TensorPtrDeleter is tied to the managed object, not the smart pointer
// itself. We need to free this memory when the object is destroyed, not
// when the smart pointer (and deleter) are eventually destroyed or reset.
tensor_impl.reset();
delete pointer;
}
};
} // namespace internal

/**
* A smart pointer for managing the lifecycle of a Tensor.
Expand All @@ -47,8 +25,62 @@ struct TensorPtrDeleter final {
* by design. It ensures that the underlying TensorImpl can be safely shared
* among tensors as needed.
*/
using TensorPtr =
std::unique_ptr<exec_aten::Tensor, internal::TensorPtrDeleter>;
class TensorPtr : private std::unique_ptr<exec_aten::Tensor> {
public:
using unique_ptr::element_type;
using unique_ptr::get;
using unique_ptr::pointer;
using unique_ptr::operator bool;
using unique_ptr::operator*;
using unique_ptr::operator->;

constexpr TensorPtr() = default;
constexpr TensorPtr(std::nullptr_t) {}
~TensorPtr() = default;
TensorPtr(TensorPtr&& rhs) noexcept = default;
TensorPtr& operator=(TensorPtr&& rhs) noexcept = default;

TensorPtr(TensorImplPtr p)
: unique_ptr(std::make_unique<exec_aten::Tensor>(p.get())),
tensor_impl(std::move(p)) {}

// release() does not make sense as it separates the Tensor from the
// TensorImplPtr. using unique_ptr::release;

void reset() {
unique_ptr::reset();
tensor_impl.reset();
}

void swap(TensorPtr& other) {
unique_ptr::swap(static_cast<unique_ptr&>(other));
tensor_impl.swap(other.tensor_impl);
}

bool operator==(const TensorPtr& rhs) const {
// No need to check tensor_impl; if the Tensor pointers are equal
// then the tensor_impls must also be since this is a unique
// pointer.
return static_cast<const unique_ptr&>(*this) ==
static_cast<const unique_ptr&>(rhs);
}

bool operator!=(const TensorPtr& rhs) const {
return !(*this == rhs);
}

bool operator==(std::nullptr_t) const {
return static_cast<const unique_ptr&>(*this) == nullptr;
}

bool operator!=(std::nullptr_t) const {
return !(*this == nullptr);
}

private:
friend TensorPtr make_tensor_ptr(const TensorPtr& tensor);
TensorImplPtr tensor_impl;
};
#else
/**
* A smart pointer type for managing the lifecycle of a Tensor.
Expand All @@ -74,9 +106,7 @@ using TensorPtr = std::unique_ptr<exec_aten::Tensor>;
*/
inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) {
#ifndef USE_ATEN_LIB
auto tensor = std::make_unique<exec_aten::Tensor>(tensor_impl.get());
return TensorPtr(
tensor.release(), internal::TensorPtrDeleter{std::move(tensor_impl)});
return TensorPtr(std::move(tensor_impl));
#else
return std::make_unique<exec_aten::Tensor>(std::move(tensor_impl));
#endif // USE_ATEN_LIB
Expand All @@ -96,7 +126,7 @@ inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) {
*/
inline TensorPtr make_tensor_ptr(const TensorPtr& tensor) {
#ifndef USE_ATEN_LIB
return make_tensor_ptr(tensor.get_deleter().tensor_impl);
return make_tensor_ptr(tensor.tensor_impl);
#else
return make_tensor_ptr(tensor->getIntrusivePtr());
#endif // USE_ATEN_LIB
Expand Down
Loading