100 changes: 66 additions & 34 deletions extension/tensor/tensor_ptr.h
@@ -10,45 +10,79 @@

 #include <executorch/extension/tensor/tensor_impl_ptr.h>
 #include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/assert.h>
 
 namespace executorch {
 namespace extension {
 
 #ifndef USE_ATEN_LIB
-namespace internal {
-
 /**
- * Custom deleter for TensorPtr that ensures proper management of the associated
- * TensorImplPtr.
- *
- * Since Tensor does not own its TensorImpl, this deleter manages the
- * TensorImplPtr lifecycle, ensuring dynamic metadata (sizes, dim_order,
- * strides) is released appropriately when the Tensor is destroyed.
+ * A smart pointer to a Tensor that owns and reference-counts its
+ * underlying TensorImpl, like torch::Tensor.
 */
-struct TensorPtrDeleter final {
-  TensorImplPtr tensor_impl;
-
-  void operator()(exec_aten::Tensor* pointer) {
-    // Release all resources immediately since the data held by the
-    // TensorPtrDeleter is tied to the managed object, not the smart pointer
-    // itself. We need to free this memory when the object is destroyed, not
-    // when the smart pointer (and deleter) are eventually destroyed or reset.
-    tensor_impl.reset();
-    delete pointer;
-  }
-};
-} // namespace internal
-
-/**
- * A smart pointer for managing the lifecycle of a Tensor.
- *
- * TensorPtr uses a unique pointer to ensure each Tensor object has distinct
- * ownership. This abstraction simplifies memory management and serves as a
- * safer alternative to the standard Tensor, which does not manage its metadata
- * by design. It ensures that the underlying TensorImpl can be safely shared
- * among tensors as needed.
- */
-using TensorPtr =
-    std::unique_ptr<exec_aten::Tensor, internal::TensorPtrDeleter>;
+class TensorPtr {
+ public:
+  constexpr TensorPtr() = default;
+  explicit constexpr TensorPtr(std::nullptr_t) {}
+  ~TensorPtr() = default;
+  TensorPtr(TensorPtr&& rhs) noexcept = default;
+  TensorPtr& operator=(TensorPtr&& rhs) noexcept = default;
+
+  explicit TensorPtr(TensorImplPtr p)
+      : tensor_(p.get()), tensor_impl_(std::move(p)) {}
+
+  operator bool() const {
+    return static_cast<bool>(tensor_impl_);
+  }
+
+  exec_aten::Tensor* get() const {
+    return tensor_impl_ ? &tensor_ : nullptr;
+  }
+
+  exec_aten::Tensor* operator->() const {
+    return get();
+  }
+
+  exec_aten::Tensor& operator*() const {
+    ET_DCHECK(*this != nullptr);
+    return *get();
+  }
+
+  void reset() {
+    tensor_ = exec_aten::Tensor(nullptr);
+    tensor_impl_.reset();
+  }
+
+  void swap(TensorPtr& other) noexcept {
+    std::swap(tensor_, other.tensor_);
+    std::swap(tensor_impl_, other.tensor_impl_);
+  }
+
+  bool operator==(const TensorPtr& rhs) const {
+    ET_DCHECK(
+        (tensor_.unsafeGetTensorImpl() == rhs.tensor_.unsafeGetTensorImpl()) ==
+        (tensor_impl_ == rhs.tensor_impl_));
+    return tensor_impl_ == rhs.tensor_impl_;
+  }
+
+  bool operator!=(const TensorPtr& rhs) const {
+    return !(*this == rhs);
+  }
+
+  bool operator==(std::nullptr_t) const {
+    return !operator bool();
+  }
+
+  bool operator!=(std::nullptr_t) const {
+    return !(*this == nullptr);
+  }
+
+ private:
+  friend TensorPtr make_tensor_ptr(const TensorPtr& tensor);
+  mutable exec_aten::Tensor tensor_{nullptr};
+  TensorImplPtr tensor_impl_;
Comment on lines +83 to +84

@swolchok (Contributor, Author) commented on Sep 26, 2024:

note that tensor_ and tensor_impl_ are redundant -- under the hood, they are both just TensorImpl* that point to the same thing. PyTorch core has also wrestled with this problem. IIRC I thought I finally cracked it within the past year or so, but never committed the PR because I didn't have a clear reason; I will try to dig it up.

@swolchok:

I was thinking of pytorch/pytorch#95418. We don't have the same problem here; rather than grafting not-reference-counting onto a reference-counting Tensor, we want to graft reference counting onto a not-reference-counting Tensor. I'll give it some more thought.

@swolchok:

this is solvable if we are able to remove get() and rely on -Waddress-of-temporary. diffs coming.
+};
 #else
 /**
  * A smart pointer type for managing the lifecycle of a Tensor.
@@ -74,9 +108,7 @@ using TensorPtr = std::unique_ptr<exec_aten::Tensor>;
  */
 inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) {
 #ifndef USE_ATEN_LIB
-  auto tensor = std::make_unique<exec_aten::Tensor>(tensor_impl.get());
-  return TensorPtr(
-      tensor.release(), internal::TensorPtrDeleter{std::move(tensor_impl)});
+  return TensorPtr(std::move(tensor_impl));
 #else
   return std::make_unique<exec_aten::Tensor>(std::move(tensor_impl));
 #endif // USE_ATEN_LIB
@@ -96,7 +128,7 @@ inline TensorPtr make_tensor_ptr(TensorImplPtr tensor_impl) {
  */
 inline TensorPtr make_tensor_ptr(const TensorPtr& tensor) {
 #ifndef USE_ATEN_LIB
-  return make_tensor_ptr(tensor.get_deleter().tensor_impl);
+  return make_tensor_ptr(tensor.tensor_impl_);
 #else
   return make_tensor_ptr(tensor->getIntrusivePtr());
 #endif // USE_ATEN_LIB
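The net effect of the hunks above: TensorPtr changes from a std::unique_ptr<exec_aten::Tensor, internal::TensorPtrDeleter> whose deleter owned the TensorImplPtr into a hand-rolled handle that stores the owning TensorImplPtr next to a non-owning Tensor view of it. A self-contained sketch of the new layout with stand-in types (Impl, View, and Handle are illustrative names, not from this PR):

#include <cassert>
#include <memory>
#include <utility>

struct Impl {
  int dim = 1;
};

// Non-owning view, mirroring how exec_aten::Tensor wraps a TensorImpl*.
struct View {
  Impl* impl = nullptr;
  int dim() const { return impl->dim; }
};

// Owning handle, mirroring the new TensorPtr: keeps the refcounted pointer
// alive and hands out pointers to its embedded non-owning view.
class Handle {
 public:
  Handle() = default;
  explicit Handle(std::shared_ptr<Impl> impl)
      : view_{impl.get()}, impl_(std::move(impl)) {}

  const View* get() const { return impl_ ? &view_ : nullptr; }
  const View* operator->() const { return get(); }
  explicit operator bool() const { return static_cast<bool>(impl_); }

  void reset() {
    view_ = View{};
    impl_.reset();
  }

 private:
  // As in the PR, view_ and impl_ both reduce to the same Impl* -- the
  // redundancy discussed in the review thread above.
  View view_;
  std::shared_ptr<Impl> impl_;
};

int main() {
  Handle h(std::make_shared<Impl>());
  assert(h && h->dim() == 1);
  Handle moved = std::move(h);
  assert(!h && moved);  // moved-from handle reads as empty, like TensorPtr
  moved.reset();
  assert(moved.get() == nullptr);
}

The defaulted move operations behave like TensorPtr's: moving nulls the owning pointer, and emptiness is observed only through it, so the stale view is never handed out.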
83 changes: 83 additions & 0 deletions extension/tensor/test/tensor_ptr_test.cpp
@@ -22,6 +22,89 @@ class TensorPtrTest : public ::testing::Test {
 }
 };
 
+TEST_F(TensorPtrTest, BasicSmartPointerAccess) {
+  TensorPtr p;
+  EXPECT_FALSE(p);
+  EXPECT_EQ(p, nullptr);
+  EXPECT_EQ(p.get(), nullptr);
+  EXPECT_EQ(p.operator->(), nullptr);
+  TensorPtr p2 = make_tensor_ptr({1}, nullptr, {}, {});
+  EXPECT_TRUE(p2);
+  EXPECT_NE(p2, nullptr);
+  ASSERT_NE(p2.get(), nullptr);
+  ASSERT_NE(p2.operator->(), nullptr);
+  EXPECT_EQ(p2.get(), p2.operator->());
+  EXPECT_EQ(p2->dim(), 1);
+  EXPECT_EQ((*p2).dim(), 1);
+  EXPECT_NE(p, p2);
+  p2.reset();
+  EXPECT_FALSE(p2);
+  EXPECT_EQ(p2, nullptr);
+  EXPECT_EQ(p2.get(), nullptr);
+  EXPECT_EQ(p2.operator->(), nullptr);
+  EXPECT_EQ(p, p2);
+}
+
+TEST_F(TensorPtrTest, Swap) {
+  TensorPtr p;
+  TensorPtr p2 = make_tensor_ptr({1}, nullptr, {}, {});
+  p.swap(p2);
+  EXPECT_FALSE(p2);
+  EXPECT_TRUE(p);
+  EXPECT_EQ(p->dim(), 1);
+}
+
+TEST_F(TensorPtrTest, MoveConstruction) {
+  TensorPtr empty;
+  TensorPtr emptyMoved(std::move(empty));
+  EXPECT_FALSE(empty); // NOLINT(bugprone-use-after-move)
+  EXPECT_FALSE(emptyMoved);
+
+  TensorPtr notEmpty = make_tensor_ptr({1}, nullptr, {}, {});
+  TensorPtr notEmptyMoved(std::move(notEmpty));
+  EXPECT_FALSE(notEmpty); // NOLINT(bugprone-use-after-move)
+  EXPECT_TRUE(notEmptyMoved);
+  EXPECT_EQ(notEmptyMoved->dim(), 1);
+}
+
+TEST_F(TensorPtrTest, MoveAssignment) {
+  {
+    TensorPtr empty, emptyMoved;
+
+    emptyMoved = std::move(empty);
+    EXPECT_FALSE(empty); // NOLINT(bugprone-use-after-move)
+    EXPECT_FALSE(emptyMoved);
+  }
+
+  {
+    TensorPtr empty;
+    TensorPtr emptyMoved = make_tensor_ptr({1}, nullptr, {}, {});
+    emptyMoved = std::move(empty);
+    EXPECT_FALSE(empty); // NOLINT(bugprone-use-after-move)
+    EXPECT_FALSE(emptyMoved);
+  }
+
+  {
+    TensorPtr full = make_tensor_ptr({1}, nullptr, {}, {});
+    TensorPtr fullMoved;
+
+    fullMoved = std::move(full);
+    EXPECT_FALSE(full); // NOLINT(bugprone-use-after-move)
+    EXPECT_TRUE(fullMoved);
+    EXPECT_EQ(fullMoved->dim(), 1);
+  }
+
+  {
+    TensorPtr full = make_tensor_ptr({1}, nullptr, {}, {});
+    TensorPtr fullMoved = make_tensor_ptr({2, 2}, nullptr, {}, {});
+
+    fullMoved = std::move(full);
+    EXPECT_FALSE(full); // NOLINT(bugprone-use-after-move)
+    EXPECT_TRUE(fullMoved);
+    EXPECT_EQ(fullMoved->dim(), 1);
+  }
+}
+
 TEST_F(TensorPtrTest, ScalarTensorCreation) {
   float scalar_data = 3.14f;
   auto tensor = make_tensor_ptr({}, &scalar_data);
2 changes: 1 addition & 1 deletion runtime/core/portable_type/tensor.h
@@ -35,7 +35,7 @@ class Tensor {
   using StridesType = TensorImpl::StridesType;
 
   Tensor() = delete;
-  explicit Tensor(TensorImpl* impl) : impl_(impl) {}
+  explicit constexpr Tensor(TensorImpl* impl) : impl_(impl) {}
 
   /**
    * Returns a pointer to the underlying TensorImpl.
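This one-line change makes Tensor's converting constructor constexpr, which lets the new TensorPtr's defaulted constructor be constexpr despite its tensor_{nullptr} member initializer. A minimal stand-alone illustration of that interaction, with hypothetical stub names:

struct ImplStub {};

struct TensorStub {
  // Mirrors the PR: making this constructor constexpr...
  explicit constexpr TensorStub(ImplStub* impl) : impl_(impl) {}
  ImplStub* impl_;
};

struct PtrStub {
  // ...lets a wrapper with a TensorStub member keep a constexpr default
  // constructor, as TensorPtr does with `constexpr TensorPtr() = default;`.
  constexpr PtrStub() = default;
  TensorStub tensor_{nullptr};
};

constexpr PtrStub kEmpty{};  // compiles only because both ctors are constexpr

int main() {
  (void)kEmpty;
  return 0;
}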