diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 62aead8978f..7473d6a1ad0 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -22,12 +22,18 @@ namespace function { namespace { -#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \ - else { \ - ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \ +#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \ + else { \ + ET_KERNEL_CHECK_MSG( \ + context, \ + false, \ + InvalidType, \ + /* void */, \ + "%zu, %zu", \ + (size_t)a.tag, \ + (size_t)b.tag); \ } -// TODO Fail using runtime context #define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \ (void)context; \ EValue& a = *stack[0]; \ @@ -168,8 +174,14 @@ static Kernel prim_ops[] = { } else if (a.isDouble() && b.isInt()) { floor_div_double(a.toDouble(), static_cast(b.toInt()), out); } else { - // TODO Fail using runtime context - ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "%zu, %zu", + (size_t)a.tag, + (size_t)b.tag); } }), @@ -193,8 +205,14 @@ static Kernel prim_ops[] = { } else if (a.isDouble() && b.isInt()) { out = EValue(a.toDouble() / b.toInt()); } else { - // TODO Fail using runtime context - ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "%zu, %zu", + (size_t)a.tag, + (size_t)b.tag); } }), @@ -214,8 +232,8 @@ static Kernel prim_ops[] = { // TODO: This should be impossible out = EValue(a.toDouble()); } else { - // TODO Fail using runtime context - ET_CHECK_MSG(false, "%zu", (size_t)a.tag); + ET_KERNEL_CHECK_MSG( + context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), @@ -265,8 +283,8 @@ static Kernel prim_ops[] = { } else if (a.isDouble()) { out = EValue(-a.toDouble()); } else { - // TODO Fail using runtime context - ET_CHECK_MSG(false, "%zu", (size_t)a.tag); + ET_KERNEL_CHECK_MSG( + context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), @@ -303,7 +321,14 @@ static Kernel prim_ops[] = { if (a.isInt() && b.isInt()) { out = EValue(a.toInt() % b.toInt()); } else { - ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "%zu, %zu", + (size_t)a.tag, + (size_t)b.tag); } }), @@ -317,7 +342,13 @@ static Kernel prim_ops[] = { if (a.isDouble()) { out = EValue(static_cast(ceil(a.toDouble()))); } else { - ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag); + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "Unsupported DType %zu", + (size_t)a.tag); } }), @@ -348,7 +379,13 @@ static Kernel prim_ops[] = { out = EValue(static_cast(res)); } else { - ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag); + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "Unsupported DType %zu", + (size_t)a.tag); } }), @@ -362,7 +399,8 @@ static Kernel prim_ops[] = { if (a.isDouble()) { out = EValue(static_cast(trunc(a.toDouble()))); } else { - ET_CHECK_MSG(false, "%zu", (size_t)a.tag); + ET_KERNEL_CHECK_MSG( + context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index 646d248cf79..f0131cb6a18 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -308,7 +308,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) { EXPECT_EQ(stack[1]->toInt(), -5l); } -TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) { +TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorFails) { testing::TensorFactory tf; EValue values[2]; @@ -325,7 +325,8 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) { } // Try to negate a tensor, which should cause a runtime error. - ET_EXPECT_DEATH(getOpsFn("executorch_prim::neg.Scalar")(context_, stack), ""); + ET_EXPECT_KERNEL_FAILURE( + context_, getOpsFn("executorch_prim::neg.Scalar")(context_, stack)); } TEST_F(RegisterPrimOpsTest, TestETView) {