From 6e4eab15e60f845609749f8f1f7b311bbccbbeeb Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 21 Apr 2026 16:56:36 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- runtime/core/evalue.h | 196 ++++++++++++++++++++++++- runtime/core/test/evalue_test.cpp | 228 ++++++++++++++++++++++++++++++ 2 files changed, 422 insertions(+), 2 deletions(-) diff --git a/runtime/core/evalue.h b/runtime/core/evalue.h index 8d75b1ace97..47d44698faf 100644 --- a/runtime/core/evalue.h +++ b/runtime/core/evalue.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -193,6 +194,13 @@ struct EValue { return payload.copyable_union.as_int; } + Result tryToInt() const { + if (!isInt()) { + return Error::InvalidType; + } + return payload.copyable_union.as_int; + } + /****** Double Type ******/ /*implicit*/ EValue(double d) : tag(Tag::Double) { payload.copyable_union.as_double = d; @@ -207,6 +215,13 @@ struct EValue { return payload.copyable_union.as_double; } + Result tryToDouble() const { + if (!isDouble()) { + return Error::InvalidType; + } + return payload.copyable_union.as_double; + } + /****** Bool Type ******/ /*implicit*/ EValue(bool b) : tag(Tag::Bool) { payload.copyable_union.as_bool = b; @@ -221,6 +236,13 @@ struct EValue { return payload.copyable_union.as_bool; } + Result tryToBool() const { + if (!isBool()) { + return Error::InvalidType; + } + return payload.copyable_union.as_bool; + } + /****** Scalar Type ******/ /// Construct an EValue using the implicit value of a Scalar. /*implicit*/ EValue(executorch::aten::Scalar s) { @@ -256,6 +278,19 @@ struct EValue { } } + Result tryToScalar() const { + if (isDouble()) { + return executorch::aten::Scalar(payload.copyable_union.as_double); + } + if (isInt()) { + return executorch::aten::Scalar(payload.copyable_union.as_int); + } + if (isBool()) { + return executorch::aten::Scalar(payload.copyable_union.as_bool); + } + return Error::InvalidType; + } + /****** Tensor Type ******/ /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) { // When built in aten mode, at::Tensor has a non trivial constructor @@ -305,6 +340,13 @@ struct EValue { return payload.as_tensor; } + Result tryToTensor() const { + if (!isTensor()) { + return Error::InvalidType; + } + return payload.as_tensor; + } + /****** String Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* s) : tag(Tag::String) { ET_CHECK_MSG(s != nullptr, "ArrayRef pointer cannot be null"); @@ -325,6 +367,18 @@ struct EValue { payload.copyable_union.as_string_ptr->size()); } + Result tryToString() const { + if (!isString()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_string_ptr == nullptr) { + return Error::InvalidState; + } + return std::string_view( + payload.copyable_union.as_string_ptr->data(), + payload.copyable_union.as_string_ptr->size()); + } + /****** Int List Type ******/ /*implicit*/ EValue(BoxedEvalueList* i) : tag(Tag::ListInt) { ET_CHECK_MSG( @@ -344,6 +398,16 @@ struct EValue { return (payload.copyable_union.as_int_list_ptr)->get(); } + Result> tryToIntList() const { + if (!isIntList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_int_list_ptr == nullptr) { + return Error::InvalidState; + } + return (payload.copyable_union.as_int_list_ptr)->get(); + } + /****** Bool List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* b) : tag(Tag::ListBool) { @@ -363,6 +427,16 @@ struct EValue { return *(payload.copyable_union.as_bool_list_ptr); } + Result> tryToBoolList() const { + if (!isBoolList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_bool_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_bool_list_ptr); + } + /****** Double List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* d) : tag(Tag::ListDouble) { @@ -382,6 +456,16 @@ struct EValue { return *(payload.copyable_union.as_double_list_ptr); } + Result> tryToDoubleList() const { + if (!isDoubleList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_double_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_double_list_ptr); + } + /****** Tensor List Type ******/ /*implicit*/ EValue(BoxedEvalueList* t) : tag(Tag::ListTensor) { @@ -402,6 +486,17 @@ struct EValue { return payload.copyable_union.as_tensor_list_ptr->get(); } + Result> tryToTensorList() + const { + if (!isTensorList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_tensor_list_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_tensor_list_ptr->get(); + } + /****** List Optional Tensor Type ******/ /*implicit*/ EValue( BoxedEvalueList>* t) @@ -426,6 +521,17 @@ struct EValue { return payload.copyable_union.as_list_optional_tensor_ptr->get(); } + Result>> + tryToListOptionalTensor() const { + if (!isListOptionalTensor()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_list_optional_tensor_ptr->get(); + } + /****** ScalarType Type ******/ executorch::aten::ScalarType toScalarType() const { ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); @@ -433,6 +539,14 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToScalarType() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** MemoryFormat Type ******/ executorch::aten::MemoryFormat toMemoryFormat() const { ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); @@ -440,12 +554,27 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToMemoryFormat() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** Layout Type ******/ executorch::aten::Layout toLayout() const { ET_CHECK_MSG(isInt(), "EValue is not a Layout."); return static_cast(payload.copyable_union.as_int); } + Result tryToLayout() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast(payload.copyable_union.as_int); + } + /****** Device Type ******/ executorch::aten::Device toDevice() const { ET_CHECK_MSG(isInt(), "EValue is not a Device."); @@ -455,6 +584,16 @@ struct EValue { -1); } + Result tryToDevice() const { + if (!isInt()) { + return Error::InvalidType; + } + return executorch::aten::Device( + static_cast( + payload.copyable_union.as_int), + -1); + } + template T to() &&; template @@ -462,6 +601,15 @@ struct EValue { template typename internal::evalue_to_ref_overload_return::type to() &; + /** + * Result-returning equivalent of `to()`. Returns `Error::InvalidType` on + * tag mismatch instead of aborting, so callers processing untrusted EValues + * (e.g., from a `.pte`) can surface the error rather than terminate. + * Specializations are defined below via `EVALUE_DEFINE_TRY_TO`. + */ + template + Result tryTo() const; + /** * Converts the EValue to an optional object that can represent both T and * an uninitialized state. @@ -474,6 +622,23 @@ struct EValue { return this->to(); } + /** + * Result-returning equivalent of `toOptional()`. None maps to an empty + * optional; any other tag that doesn't match T propagates `tryTo()`'s + * error (`Error::InvalidType`). + */ + template + inline Result> tryToOptional() const { + if (this->isNone()) { + return std::optional(executorch::aten::nullopt); + } + auto r = this->tryTo(); + if (!r.ok()) { + return r.error(); + } + return std::optional(std::move(r.get())); + } + private: // Pre cond: the payload value has had its destructor called void clearToNone() noexcept { @@ -524,7 +689,7 @@ struct EValue { #define EVALUE_DEFINE_TO(T, method_name) \ template <> \ - inline T EValue::to()&& { \ + inline T EValue::to() && { \ return static_cast(std::move(*this).method_name()); \ } \ template <> \ @@ -538,7 +703,7 @@ struct EValue { template <> \ inline ::executorch::runtime::internal::evalue_to_ref_overload_return< \ T>::type \ - EValue::to()& { \ + EValue::to() & { \ typedef ::executorch::runtime::internal::evalue_to_ref_overload_return< \ T>::type return_type; \ return static_cast(this->method_name()); \ @@ -591,6 +756,33 @@ EVALUE_DEFINE_TO( toListOptionalTensor) #undef EVALUE_DEFINE_TO +#define EVALUE_DEFINE_TRY_TO(T, method_name) \ + template <> \ + inline Result EValue::tryTo() const { \ + return this->method_name(); \ + } + +EVALUE_DEFINE_TRY_TO(executorch::aten::Scalar, tryToScalar) +EVALUE_DEFINE_TRY_TO(int64_t, tryToInt) +EVALUE_DEFINE_TRY_TO(bool, tryToBool) +EVALUE_DEFINE_TRY_TO(double, tryToDouble) +EVALUE_DEFINE_TRY_TO(std::string_view, tryToString) +EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType) +EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat) +EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout) +EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice) +EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor) +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToIntList) +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToDoubleList) +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToBoolList) +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef, + tryToTensorList) +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef>, + tryToListOptionalTensor) +#undef EVALUE_DEFINE_TRY_TO + template executorch::aten::ArrayRef BoxedEvalueList::get() const { for (typename executorch::aten::ArrayRef::size_type i = 0; diff --git a/runtime/core/test/evalue_test.cpp b/runtime/core/test/evalue_test.cpp index edf6a1b12c1..18060bd8c82 100644 --- a/runtime/core/test/evalue_test.cpp +++ b/runtime/core/test/evalue_test.cpp @@ -417,3 +417,231 @@ TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) { EXPECT_TRUE(e.isListOptionalTensor()); ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "pointer is null"); } + +TEST_F(EValueTest, TryToTensorSuccess) { + TensorFactory tf; + EValue e(tf.ones({3, 2})); + auto result = e.tryToTensor(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->dim(), 2); + EXPECT_EQ(result->numel(), 6); +} + +TEST_F(EValueTest, TryToTensorTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToTensor(); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToOptionalTensorSuccess) { + TensorFactory tf; + EValue e(tf.ones({3, 2})); + auto result = e.tryToOptional(); + EXPECT_TRUE(result.ok()); + EXPECT_TRUE(result->has_value()); + EXPECT_EQ(result->value().dim(), 2); +} + +TEST_F(EValueTest, TryToOptionalTensorNone) { + EValue e; + auto result = e.tryToOptional(); + EXPECT_TRUE(result.ok()); + EXPECT_FALSE(result->has_value()); +} + +TEST_F(EValueTest, TryToOptionalTensorTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToOptional(); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +// Scalar/primitive tryTo* coverage. Each test pair exercises the match and +// mismatch paths; the type-mismatch check uses an int64_t EValue (or a Tensor +// EValue when testing tryToInt) to guarantee the tag disagrees. + +TEST_F(EValueTest, TryToIntSuccess) { + EValue e(static_cast(42)); + auto result = e.tryToInt(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.get(), 42); +} + +TEST_F(EValueTest, TryToIntTypeMismatch) { + EValue e(3.14); + auto result = e.tryToInt(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToDoubleSuccess) { + EValue e(3.14); + auto result = e.tryToDouble(); + EXPECT_TRUE(result.ok()); + EXPECT_DOUBLE_EQ(result.get(), 3.14); +} + +TEST_F(EValueTest, TryToDoubleTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToDouble(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToBoolSuccess) { + EValue e(true); + auto result = e.tryToBool(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.get(), true); +} + +TEST_F(EValueTest, TryToBoolTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToBool(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToScalarFromInt) { + EValue e(static_cast(7)); + auto result = e.tryToScalar(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->to(), 7); +} + +TEST_F(EValueTest, TryToScalarFromDouble) { + EValue e(2.5); + auto result = e.tryToScalar(); + EXPECT_TRUE(result.ok()); + EXPECT_DOUBLE_EQ(result->to(), 2.5); +} + +TEST_F(EValueTest, TryToScalarFromBool) { + EValue e(true); + auto result = e.tryToScalar(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->to(), true); +} + +TEST_F(EValueTest, TryToScalarNoneTag) { + // None is neither Int/Double/Bool, so tryToScalar must reject it. + EValue e; + auto result = e.tryToScalar(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToScalarTypeTagReturnsScalarType) { + // ScalarType/MemoryFormat/Layout/Device share the Int tag; exercise each. + EValue e(static_cast(static_cast(ScalarType::Float))); + auto st = e.tryToScalarType(); + EXPECT_TRUE(st.ok()); + EXPECT_EQ(st.get(), ScalarType::Float); +} + +TEST_F(EValueTest, TryToScalarTypeTypeMismatch) { + EValue e(3.14); + auto result = e.tryToScalarType(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToMemoryFormatTypeMismatch) { + EValue e(3.14); + auto result = e.tryToMemoryFormat(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToLayoutTypeMismatch) { + EValue e(3.14); + auto result = e.tryToLayout(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToDeviceTypeMismatch) { + EValue e(3.14); + auto result = e.tryToDevice(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +// List tryTo* — mismatch paths. Success paths require building a +// BoxedEvalueList or ArrayRef host object, which the non-list tests above +// already cover via the shared tag-check logic; one mismatch test per list +// type is enough to exercise the added code. + +TEST_F(EValueTest, TryToIntListTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToIntList(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToDoubleListTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToDoubleList(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToBoolListTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToBoolList(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToTensorListTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToTensorList(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToListOptionalTensorTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToListOptionalTensor(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToStringTypeMismatch) { + EValue e(static_cast(42)); + auto result = e.tryToString(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +// Templated tryTo() dispatcher. Matches and mismatches should behave +// identically to the named tryToX methods. + +TEST_F(EValueTest, TryToTemplateIntSuccess) { + EValue e(static_cast(42)); + auto result = e.tryTo(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result.get(), 42); +} + +TEST_F(EValueTest, TryToTemplateIntMismatch) { + EValue e(3.14); + auto result = e.tryTo(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, TryToTemplateTensorSuccess) { + TensorFactory tf; + EValue e(tf.ones({3, 2})); + auto result = e.tryTo(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->numel(), 6); +} + +TEST_F(EValueTest, TryToOptionalIntSuccess) { + EValue e(static_cast(42)); + auto result = e.tryToOptional(); + EXPECT_TRUE(result.ok()); + EXPECT_TRUE(result->has_value()); + EXPECT_EQ(result->value(), 42); +} + +TEST_F(EValueTest, TryToOptionalIntNone) { + EValue e; + auto result = e.tryToOptional(); + EXPECT_TRUE(result.ok()); + EXPECT_FALSE(result->has_value()); +} + +TEST_F(EValueTest, TryToOptionalIntTypeMismatch) { + EValue e(3.14); + auto result = e.tryToOptional(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +}