Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions kernels/prim_ops/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ namespace function {

namespace {

#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
else { \
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
else { \
ET_KERNEL_CHECK_MSG( \
context, \
false, \
InvalidType, \
/* void */, \
"%zu, %zu", \
(size_t)a.tag, \
(size_t)b.tag); \
}

// TODO Fail using runtime context
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
(void)context; \
EValue& a = *stack[0]; \
Expand Down Expand Up @@ -168,8 +174,14 @@ static Kernel prim_ops[] = {
} else if (a.isDouble() && b.isInt()) {
floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
ET_KERNEL_CHECK_MSG(
context,
false,
InvalidType,
/* void */,
"%zu, %zu",
(size_t)a.tag,
(size_t)b.tag);
}
}),

Expand All @@ -193,8 +205,14 @@ static Kernel prim_ops[] = {
} else if (a.isDouble() && b.isInt()) {
out = EValue(a.toDouble() / b.toInt());
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
ET_KERNEL_CHECK_MSG(
context,
false,
InvalidType,
/* void */,
"%zu, %zu",
(size_t)a.tag,
(size_t)b.tag);
}
}),

Expand All @@ -214,8 +232,8 @@ static Kernel prim_ops[] = {
// TODO: This should be impossible
out = EValue(a.toDouble());
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
ET_KERNEL_CHECK_MSG(
context, false, InvalidType, /* void */, "%zu", (size_t)a.tag);
}
}),

Expand Down Expand Up @@ -265,8 +283,8 @@ static Kernel prim_ops[] = {
} else if (a.isDouble()) {
out = EValue(-a.toDouble());
} else {
// TODO Fail using runtime context
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
ET_KERNEL_CHECK_MSG(
context, false, InvalidType, /* void */, "%zu", (size_t)a.tag);
}
}),

Expand Down Expand Up @@ -303,7 +321,14 @@ static Kernel prim_ops[] = {
if (a.isInt() && b.isInt()) {
out = EValue(a.toInt() % b.toInt());
} else {
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
ET_KERNEL_CHECK_MSG(
context,
false,
InvalidType,
/* void */,
"%zu, %zu",
(size_t)a.tag,
(size_t)b.tag);
}
}),

Expand All @@ -317,7 +342,13 @@ static Kernel prim_ops[] = {
if (a.isDouble()) {
out = EValue(static_cast<int64_t>(ceil(a.toDouble())));
} else {
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
ET_KERNEL_CHECK_MSG(
context,
false,
InvalidType,
/* void */,
"Unsupported DType %zu",
(size_t)a.tag);
}
}),

Expand Down Expand Up @@ -348,7 +379,13 @@ static Kernel prim_ops[] = {

out = EValue(static_cast<int64_t>(res));
} else {
ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag);
ET_KERNEL_CHECK_MSG(
context,
false,
InvalidType,
/* void */,
"Unsupported DType %zu",
(size_t)a.tag);
}
}),

Expand All @@ -362,7 +399,8 @@ static Kernel prim_ops[] = {
if (a.isDouble()) {
out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
} else {
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
ET_KERNEL_CHECK_MSG(
context, false, InvalidType, /* void */, "%zu", (size_t)a.tag);
}
}),

Expand Down
5 changes: 3 additions & 2 deletions kernels/prim_ops/test/prim_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
EXPECT_EQ(stack[1]->toInt(), -5l);
}

TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorFails) {
testing::TensorFactory<ScalarType::Int> tf;

EValue values[2];
Expand All @@ -325,7 +325,8 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
}

// Try to negate a tensor, which should cause a runtime error.
ET_EXPECT_DEATH(getOpsFn("executorch_prim::neg.Scalar")(context_, stack), "");
ET_EXPECT_KERNEL_FAILURE(
context_, getOpsFn("executorch_prim::neg.Scalar")(context_, stack));
}

TEST_F(RegisterPrimOpsTest, TestETView) {
Expand Down
Loading