Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions runtime/core/evalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,23 @@ struct EValue {
new (&payload.as_tensor) exec_aten::Tensor(t);
}

// Template constructor that allows construction from types that can be
// dereferenced to produce a type that EValue can be implicitly constructed
// from.
template <typename T>
/*implicit*/ EValue(
T&& value,
typename std::enable_if<std::is_convertible<
decltype(*std::forward<T>(value)),
EValue>::value>::type* = 0) {
ET_CHECK_MSG(value != nullptr, "Pointer is null.");
*this = EValue(*std::forward<T>(value));
}

// Delete constructor for raw pointers to ensure they cannot be used.
template <typename T>
explicit EValue(T* value) = delete;

bool isTensor() const {
return tag == Tag::Tensor;
}
Expand Down
132 changes: 113 additions & 19 deletions runtime/core/test/evalue_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,67 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/core/evalue.h>

#include <gtest/gtest.h>

#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>

using namespace ::testing;

namespace torch {
namespace executor {

using exec_aten::ScalarType;
using executorch::runtime::BoxedEvalueList;
using executorch::runtime::EValue;
using executorch::runtime::Tag;
using executorch::runtime::testing::TensorFactory;

TEST(TestEValue, CopyTrivialType) {
class EValueTest : public ::testing::Test {
protected:
void SetUp() override {
// Since these tests cause ET_LOG to be called, the PAL must be initialized
// first.
runtime_init();
}
};

// An utility class used in tests to simulate objects that manage Tensors.
// The overloaded operator*() is used to return the underlying Tensor, mimicking
// behavior of smart pointers.
class TensorWrapper {
public:
explicit TensorWrapper(exec_aten::Tensor tensor)
: tensor_(std::make_unique<exec_aten::Tensor>(std::move(tensor))) {}

exec_aten::Tensor& operator*() {
return *tensor_;
}

const exec_aten::Tensor& operator*() const {
return *tensor_;
}

operator bool() const {
return static_cast<bool>(tensor_);
}

bool operator==(std::nullptr_t) const {
return tensor_ == nullptr;
}

bool operator!=(std::nullptr_t) const {
return tensor_ != nullptr;
}

private:
std::unique_ptr<exec_aten::Tensor> tensor_;
};

TEST_F(EValueTest, CopyTrivialType) {
EValue a;
EValue b(true);
EXPECT_TRUE(a.isNone());
Expand All @@ -30,7 +76,7 @@ TEST(TestEValue, CopyTrivialType) {
EXPECT_EQ(b.to<bool>(), true);
}

TEST(TestEValue, CopyTensor) {
TEST_F(EValueTest, CopyTensor) {
TensorFactory<ScalarType::Float> tf;
EValue a(tf.ones({3, 2}));
EValue b(tf.ones({1}));
Expand All @@ -39,7 +85,7 @@ TEST(TestEValue, CopyTensor) {
EXPECT_EQ(a.toTensor().dim(), 1);
}

TEST(TestEValue, TypeMismatchFatals) {
TEST_F(EValueTest, TypeMismatchFatals) {
ET_EXPECT_DEATH(
{
auto e = EValue(true);
Expand All @@ -48,12 +94,12 @@ TEST(TestEValue, TypeMismatchFatals) {
"");
}

TEST(TestEValue, NoneByDefault) {
TEST_F(EValueTest, NoneByDefault) {
EValue e;
EXPECT_TRUE(e.isNone());
}

TEST(TestEValue, ToOptionalInt) {
TEST_F(EValueTest, ToOptionalInt) {
EValue e((int64_t)5);
EXPECT_TRUE(e.isInt());
EXPECT_FALSE(e.isNone());
Expand All @@ -63,15 +109,15 @@ TEST(TestEValue, ToOptionalInt) {
EXPECT_EQ(o.value(), 5);
}

TEST(TestEValue, NoneToOptionalInt) {
TEST_F(EValueTest, NoneToOptionalInt) {
EValue e;
EXPECT_TRUE(e.isNone());

exec_aten::optional<int64_t> o = e.toOptional<int64_t>();
EXPECT_FALSE(o.has_value());
}

TEST(TestEValue, ToOptionalScalar) {
TEST_F(EValueTest, ToOptionalScalar) {
exec_aten::Scalar s((double)3.141);
EValue e(s);
EXPECT_TRUE(e.isScalar());
Expand All @@ -83,7 +129,7 @@ TEST(TestEValue, ToOptionalScalar) {
EXPECT_EQ(o.value().to<double>(), 3.141);
}

TEST(TESTEValue, ScalarToType) {
TEST_F(EValueTest, ScalarToType) {
exec_aten::Scalar s_d((double)3.141);
EXPECT_EQ(s_d.to<double>(), 3.141);
exec_aten::Scalar s_i((int64_t)3);
Expand All @@ -92,23 +138,23 @@ TEST(TESTEValue, ScalarToType) {
EXPECT_EQ(s_b.to<bool>(), true);
}

TEST(TestEValue, NoneToOptionalScalar) {
TEST_F(EValueTest, NoneToOptionalScalar) {
EValue e;
EXPECT_TRUE(e.isNone());

exec_aten::optional<exec_aten::Scalar> o = e.toOptional<exec_aten::Scalar>();
EXPECT_FALSE(o.has_value());
}

TEST(TestEValue, NoneToOptionalTensor) {
TEST_F(EValueTest, NoneToOptionalTensor) {
EValue e;
EXPECT_TRUE(e.isNone());

exec_aten::optional<exec_aten::Tensor> o = e.toOptional<exec_aten::Tensor>();
EXPECT_FALSE(o.has_value());
}

TEST(TestEValue, ToScalarType) {
TEST_F(EValueTest, ToScalarType) {
EValue e((int64_t)4);
auto o = e.toScalarType();
EXPECT_EQ(o, exec_aten::ScalarType::Long);
Expand All @@ -118,7 +164,7 @@ TEST(TestEValue, ToScalarType) {
EXPECT_EQ(o2.value(), exec_aten::ScalarType::Long);
}

TEST(TestEValue, toString) {
TEST_F(EValueTest, toString) {
const EValue e("foo", 3);
EXPECT_TRUE(e.isString());
EXPECT_FALSE(e.isNone());
Expand All @@ -127,28 +173,28 @@ TEST(TestEValue, toString) {
EXPECT_EQ(x, "foo");
}

TEST(TestEValue, MemoryFormat) {
TEST_F(EValueTest, MemoryFormat) {
const EValue e((int64_t)0);
EXPECT_TRUE(e.isInt());
const exec_aten::MemoryFormat m = e.to<exec_aten::MemoryFormat>();
EXPECT_EQ(m, exec_aten::MemoryFormat::Contiguous);
}

TEST(TestEValue, Layout) {
TEST_F(EValueTest, Layout) {
const EValue e((int64_t)0);
EXPECT_TRUE(e.isInt());
const exec_aten::Layout l = e.to<exec_aten::Layout>();
EXPECT_EQ(l, exec_aten::Layout::Strided);
}

TEST(TestEValue, Device) {
TEST_F(EValueTest, Device) {
const EValue e((int64_t)0);
EXPECT_TRUE(e.isInt());
const exec_aten::Device d = e.to<exec_aten::Device>();
EXPECT_TRUE(d.is_cpu());
}

TEST(TestEValue, BoxedEvalueList) {
TEST_F(EValueTest, BoxedEvalueList) {
// create fake values table to point to
EValue values[3] = {
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
Expand All @@ -164,7 +210,7 @@ TEST(TestEValue, BoxedEvalueList) {
EXPECT_EQ(unwrapped[2], 3);
}

TEST(TestEValue, toOptionalTensorList) {
TEST_F(EValueTest, toOptionalTensorList) {
// create list, empty evalue ctor gets tag::None
EValue values[2] = {EValue(), EValue()};
EValue* values_p[2] = {&values[0], &values[1]};
Expand All @@ -185,3 +231,51 @@ TEST(TestEValue, toOptionalTensorList) {
EXPECT_FALSE(x[0].has_value());
EXPECT_FALSE(x[1].has_value());
}

TEST_F(EValueTest, ConstructFromUniquePtr) {
TensorFactory<ScalarType::Float> tf;
auto tensor_ptr = std::make_unique<exec_aten::Tensor>(tf.ones({2, 3}));

EValue evalue(std::move(tensor_ptr));

EXPECT_TRUE(evalue.isTensor());
EXPECT_EQ(evalue.toTensor().dim(), 2);
EXPECT_EQ(evalue.toTensor().numel(), 6);

EValue evalue2(std::make_unique<exec_aten::Tensor>(tf.ones({4, 5})));

EXPECT_TRUE(evalue2.isTensor());
EXPECT_EQ(evalue2.toTensor().dim(), 2);
EXPECT_EQ(evalue2.toTensor().numel(), 20);
}

TEST_F(EValueTest, ConstructFromSharedPtr) {
TensorFactory<ScalarType::Float> tf;
auto tensor_ptr = std::make_shared<exec_aten::Tensor>(tf.ones({4, 5}));

EValue evalue(tensor_ptr);

EXPECT_TRUE(evalue.isTensor());
EXPECT_EQ(evalue.toTensor().dim(), 2);
EXPECT_EQ(evalue.toTensor().numel(), 20);
}

TEST_F(EValueTest, ConstructFromTensorWrapper) {
TensorFactory<ScalarType::Float> tf;
TensorWrapper tensor_wrapper(tf.ones({4, 5}));

EValue evalue(tensor_wrapper);

EXPECT_TRUE(evalue.isTensor());
EXPECT_EQ(evalue.toTensor().dim(), 2);
EXPECT_EQ(evalue.toTensor().numel(), 20);
}

TEST_F(EValueTest, ConstructFromNullPtrAborts) {
std::unique_ptr<exec_aten::Tensor> null_ptr;

ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
}

} // namespace executor
} // namespace torch