From 9f33a182ad301cde22a97fd8ac14a5ce1b4a5d59 Mon Sep 17 00:00:00 2001
From: Anthony Shoumikhin
Date: Mon, 26 Aug 2024 18:33:26 -0700
Subject: [PATCH] Allow EValue to be constructed with a smart pointer
 implicitly. (#4902)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4902

Consider https://github.com/pytorch/executorch/blob/dc66414c70dbec763fa37aba7908d50373299435/examples/models/llava/runner/llava_image_prefiller.h#L31-L33: if ManagedTensor had an operator*() overload, the managed tensor could be passed to Module::execute() directly, e.g. `module_->execute(kImageEncoderMethod, managed_images)`.

Reviewed By: dbort

Differential Revision: D61783902
---
 runtime/core/evalue.h             |  17 ++++
 runtime/core/test/evalue_test.cpp | 132 +++++++++++++++++++++++++-----
 2 files changed, 130 insertions(+), 19 deletions(-)

diff --git a/runtime/core/evalue.h b/runtime/core/evalue.h
index 8aee5f399df..c0c534e0692 100644
--- a/runtime/core/evalue.h
+++ b/runtime/core/evalue.h
@@ -238,6 +238,23 @@ struct EValue {
     new (&payload.as_tensor) exec_aten::Tensor(t);
   }
 
+  // Template constructor that allows construction from types that can be
+  // dereferenced to produce a type that EValue can be implicitly constructed
+  // from.
+  template <typename T>
+  /*implicit*/ EValue(
+      T&& value,
+      typename std::enable_if<std::is_convertible<
+          decltype(*std::forward<T>(value)),
+          EValue>::value>::type* = 0) {
+    ET_CHECK_MSG(value != nullptr, "Pointer is null.");
+    *this = EValue(*std::forward<T>(value));
+  }
+
+  // Delete constructor for raw pointers to ensure they cannot be used.
+  template <typename T>
+  explicit EValue(T* value) = delete;
+
   bool isTensor() const {
     return tag == Tag::Tensor;
   }
diff --git a/runtime/core/test/evalue_test.cpp b/runtime/core/test/evalue_test.cpp
index bc3e3a7913b..4c08695dc4b 100644
--- a/runtime/core/test/evalue_test.cpp
+++ b/runtime/core/test/evalue_test.cpp
@@ -6,21 +6,67 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/runtime/core/evalue.h>
+
 #include <gtest/gtest.h>
-#include <executorch/runtime/core/evalue.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/platform/runtime.h>
 #include <executorch/test/utils/DeathTest.h>
 
 using namespace ::testing;
 
+namespace torch {
+namespace executor {
+
 using exec_aten::ScalarType;
 using executorch::runtime::BoxedEvalueList;
 using executorch::runtime::EValue;
 using executorch::runtime::Tag;
 using executorch::runtime::testing::TensorFactory;
 
-TEST(TestEValue, CopyTrivialType) {
+class EValueTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Since these tests cause ET_LOG to be called, the PAL must be initialized
+    // first.
+    runtime_init();
+  }
+};
+
+// A utility class used in tests to simulate objects that manage Tensors.
+// The overloaded operator*() returns the underlying Tensor, mimicking the
+// behavior of smart pointers.
+class TensorWrapper {
+ public:
+  explicit TensorWrapper(exec_aten::Tensor tensor)
+      : tensor_(std::make_unique<exec_aten::Tensor>(std::move(tensor))) {}
+
+  exec_aten::Tensor& operator*() {
+    return *tensor_;
+  }
+
+  const exec_aten::Tensor& operator*() const {
+    return *tensor_;
+  }
+
+  operator bool() const {
+    return static_cast<bool>(tensor_);
+  }
+
+  bool operator==(std::nullptr_t) const {
+    return tensor_ == nullptr;
+  }
+
+  bool operator!=(std::nullptr_t) const {
+    return tensor_ != nullptr;
+  }
+
+ private:
+  std::unique_ptr<exec_aten::Tensor> tensor_;
+};
+
+TEST_F(EValueTest, CopyTrivialType) {
   EValue a;
   EValue b(true);
   EXPECT_TRUE(a.isNone());
@@ -30,7 +76,7 @@ TEST(TestEValue, CopyTrivialType) {
   EXPECT_EQ(b.to<bool>(), true);
 }
 
-TEST(TestEValue, CopyTensor) {
+TEST_F(EValueTest, CopyTensor) {
   TensorFactory<ScalarType::Int> tf;
   EValue a(tf.ones({3, 2}));
   EValue b(tf.ones({1}));
@@ -39,7 +85,7 @@ TEST(TestEValue, CopyTensor) {
   EXPECT_EQ(a.toTensor().dim(), 1);
 }
 
-TEST(TestEValue, TypeMismatchFatals) {
+TEST_F(EValueTest, TypeMismatchFatals) {
   ET_EXPECT_DEATH(
       {
         auto e = EValue(true);
@@ -48,12 +94,12 @@ TEST(TestEValue, TypeMismatchFatals) {
       "");
 }
 
-TEST(TestEValue, NoneByDefault) {
+TEST_F(EValueTest, NoneByDefault) {
   EValue e;
   EXPECT_TRUE(e.isNone());
 }
 
-TEST(TestEValue, ToOptionalInt) {
+TEST_F(EValueTest, ToOptionalInt) {
   EValue e((int64_t)5);
   EXPECT_TRUE(e.isInt());
   EXPECT_FALSE(e.isNone());
@@ -63,7 +109,7 @@ TEST(TestEValue, ToOptionalInt) {
   EXPECT_EQ(o.value(), 5);
 }
 
-TEST(TestEValue, NoneToOptionalInt) {
+TEST_F(EValueTest, NoneToOptionalInt) {
   EValue e;
   EXPECT_TRUE(e.isNone());
@@ -71,7 +117,7 @@ TEST(TestEValue, NoneToOptionalInt) {
   exec_aten::optional<int64_t> o = e.toOptional<int64_t>();
   EXPECT_FALSE(o.has_value());
 }
 
-TEST(TestEValue, ToOptionalScalar) {
+TEST_F(EValueTest, ToOptionalScalar) {
   exec_aten::Scalar s((double)3.141);
   EValue e(s);
   EXPECT_TRUE(e.isScalar());
@@ -83,7 +129,7 @@ TEST(TestEValue, ToOptionalScalar) {
   EXPECT_EQ(o.value().to<double>(), 3.141);
 }
 
-TEST(TESTEValue, ScalarToType) {
+TEST_F(EValueTest, ScalarToType) {
   exec_aten::Scalar s_d((double)3.141);
   EXPECT_EQ(s_d.to<double>(), 3.141);
   exec_aten::Scalar s_i((int64_t)3);
@@ -92,7 +138,7 @@ TEST(TESTEValue, ScalarToType) {
   EXPECT_EQ(s_b.to<bool>(), true);
 }
 
-TEST(TestEValue, NoneToOptionalScalar) {
+TEST_F(EValueTest, NoneToOptionalScalar) {
   EValue e;
   EXPECT_TRUE(e.isNone());
@@ -100,7 +146,7 @@ TEST(TestEValue, NoneToOptionalScalar) {
   exec_aten::optional<exec_aten::Scalar> o = e.toOptional<exec_aten::Scalar>();
   EXPECT_FALSE(o.has_value());
 }
 
-TEST(TestEValue, NoneToOptionalTensor) {
+TEST_F(EValueTest, NoneToOptionalTensor) {
   EValue e;
   EXPECT_TRUE(e.isNone());
@@ -108,7 +154,7 @@ TEST(TestEValue, NoneToOptionalTensor) {
   exec_aten::optional<exec_aten::Tensor> o = e.toOptional<exec_aten::Tensor>();
   EXPECT_FALSE(o.has_value());
 }
 
-TEST(TestEValue, ToScalarType) {
+TEST_F(EValueTest, ToScalarType) {
   EValue e((int64_t)4);
   auto o = e.toScalarType();
   EXPECT_EQ(o, exec_aten::ScalarType::Long);
@@ -118,7 +164,7 @@ TEST(TestEValue, ToScalarType) {
   EXPECT_EQ(o2.value(), exec_aten::ScalarType::Long);
 }
 
-TEST(TestEValue, toString) {
+TEST_F(EValueTest, toString) {
   const EValue e("foo", 3);
   EXPECT_TRUE(e.isString());
   EXPECT_FALSE(e.isNone());
@@ -127,28 +173,28 @@ TEST(TestEValue, toString) {
   EXPECT_EQ(x, "foo");
 }
 
-TEST(TestEValue, MemoryFormat) {
+TEST_F(EValueTest, MemoryFormat) {
   const EValue e((int64_t)0);
   EXPECT_TRUE(e.isInt());
   const exec_aten::MemoryFormat m = e.to<exec_aten::MemoryFormat>();
   EXPECT_EQ(m, exec_aten::MemoryFormat::Contiguous);
 }
 
-TEST(TestEValue, Layout) {
+TEST_F(EValueTest, Layout) {
   const EValue e((int64_t)0);
   EXPECT_TRUE(e.isInt());
   const exec_aten::Layout l = e.to<exec_aten::Layout>();
   EXPECT_EQ(l, exec_aten::Layout::Strided);
 }
 
-TEST(TestEValue, Device) {
+TEST_F(EValueTest, Device) {
   const EValue e((int64_t)0);
   EXPECT_TRUE(e.isInt());
   const exec_aten::Device d = e.to<exec_aten::Device>();
   EXPECT_TRUE(d.is_cpu());
 }
 
-TEST(TestEValue, BoxedEvalueList) {
+TEST_F(EValueTest, BoxedEvalueList) {
   // create fake values table to point to
   EValue values[3] = {
       EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
@@ -164,7 +210,7 @@ TEST(TestEValue, BoxedEvalueList) {
   EXPECT_EQ(unwrapped[2], 3);
 }
 
-TEST(TestEValue, toOptionalTensorList) {
+TEST_F(EValueTest, toOptionalTensorList) {
   // create list, empty evalue ctor gets tag::None
   EValue values[2] = {EValue(), EValue()};
   EValue* values_p[2] = {&values[0], &values[1]};
@@ -185,3 +231,51 @@ TEST(TestEValue, toOptionalTensorList) {
   EXPECT_FALSE(x[0].has_value());
   EXPECT_FALSE(x[1].has_value());
 }
+
+TEST_F(EValueTest, ConstructFromUniquePtr) {
+  TensorFactory<ScalarType::Int> tf;
+  auto tensor_ptr = std::make_unique<exec_aten::Tensor>(tf.ones({2, 3}));
+
+  EValue evalue(std::move(tensor_ptr));
+
+  EXPECT_TRUE(evalue.isTensor());
+  EXPECT_EQ(evalue.toTensor().dim(), 2);
+  EXPECT_EQ(evalue.toTensor().numel(), 6);
+
+  EValue evalue2(std::make_unique<exec_aten::Tensor>(tf.ones({4, 5})));
+
+  EXPECT_TRUE(evalue2.isTensor());
+  EXPECT_EQ(evalue2.toTensor().dim(), 2);
+  EXPECT_EQ(evalue2.toTensor().numel(), 20);
+}
+
+TEST_F(EValueTest, ConstructFromSharedPtr) {
+  TensorFactory<ScalarType::Int> tf;
+  auto tensor_ptr = std::make_shared<exec_aten::Tensor>(tf.ones({4, 5}));
+
+  EValue evalue(tensor_ptr);
+
+  EXPECT_TRUE(evalue.isTensor());
+  EXPECT_EQ(evalue.toTensor().dim(), 2);
+  EXPECT_EQ(evalue.toTensor().numel(), 20);
+}
+
+TEST_F(EValueTest, ConstructFromTensorWrapper) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorWrapper tensor_wrapper(tf.ones({4, 5}));
+
+  EValue evalue(tensor_wrapper);
+
+  EXPECT_TRUE(evalue.isTensor());
+  EXPECT_EQ(evalue.toTensor().dim(), 2);
+  EXPECT_EQ(evalue.toTensor().numel(), 20);
+}
+
+TEST_F(EValueTest, ConstructFromNullPtrAborts) {
+  std::unique_ptr<exec_aten::Tensor> null_ptr;
+
+  ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
+}
+
+} // namespace executor
+} // namespace torch
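
For reference, a minimal usage sketch of the constructor added above, distilled from the `ConstructFromUniquePtr` and `ConstructFromSharedPtr` tests. The header paths, namespaces, and the `TensorFactory` helper are assumed to match the test file in this patch; treat the snippet as an illustrative sketch rather than a verified build target.

```cpp
// Sketch only: headers and namespaces follow the test file in this patch.
#include <memory>

#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/platform/runtime.h>

using executorch::runtime::EValue;
using executorch::runtime::testing::TensorFactory;

int main() {
  // Initialize the PAL first, as EValueTest::SetUp() does above.
  executorch::runtime::runtime_init();

  TensorFactory<exec_aten::ScalarType::Int> tf;

  // Any wrapper whose operator*() yields something convertible to EValue can
  // now be passed where an EValue is expected; the wrapper is dereferenced
  // and the resulting Tensor is wrapped.
  auto unique_tensor = std::make_unique<exec_aten::Tensor>(tf.ones({2, 3}));
  EValue from_unique(std::move(unique_tensor));

  auto shared_tensor = std::make_shared<exec_aten::Tensor>(tf.ones({4, 5}));
  EValue from_shared(shared_tensor);

  // Raw pointers are rejected at compile time by the deleted overload:
  //   exec_aten::Tensor* raw = ...;
  //   EValue bad(raw);  // error: use of deleted constructor
  // A null smart pointer aborts at runtime via ET_CHECK_MSG
  // (see ConstructFromNullPtrAborts above).

  return from_unique.toTensor().numel() == 6 &&
                 from_shared.toTensor().numel() == 20
             ? 0
             : 1;
}
```

Deleting the raw-pointer overload keeps the constructor from silently accepting a pointer with unclear ownership; only dereferenceable wrappers are accepted, and the constructor null-checks them before dereferencing.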