diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp
index 7a1ed0ef4ac..c11c9977fe5 100644
--- a/kernels/optimized/cpu/op_add.cpp
+++ b/kernels/optimized/cpu/op_add.cpp
@@ -36,11 +36,17 @@ Tensor& opt_add_out(
       a_type != ScalarType::Half) {
     // Resize for dynamic shape
     auto error = resize_tensor(out, a.sizes());
-    ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        error == Error::Ok,
+        InvalidArgument,
+        out,
+        "Failed to resize output tensor.");
 
     ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
       CTYPE alpha_val;
-      ET_EXTRACT_SCALAR(alpha, alpha_val);
+      ET_KERNEL_CHECK(
+          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
 
       using Vec = executorch::vec::Vectorized;
       executorch::vec::map2(
@@ -53,7 +59,7 @@ Tensor& opt_add_out(
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_CHECK(canCast(common_type, out_type));
+    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
 
     ET_KERNEL_CHECK(
         ctx,
@@ -66,7 +72,10 @@ Tensor& opt_add_out(
     ET_SWITCH_REALB_TYPES(common_type, ctx, "add.out", CTYPE_IN, [&]() {
       ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
         CTYPE_IN alpha_val;
-        ET_EXTRACT_SCALAR(alpha, alpha_val);
+        ET_KERNEL_CHECK(
+            ctx,
+            utils::extract_scalar(alpha, &alpha_val),
+            InvalidArgument, );
 
         apply_binary_elementwise_fn(
             [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp
index 57d7c05ea74..a532cfc7ba6 100644
--- a/kernels/portable/cpu/op_add.cpp
+++ b/kernels/portable/cpu/op_add.cpp
@@ -29,6 +29,8 @@ Tensor& add_out(
       InvalidArgument,
       out);
 
+  ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
+
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType alpha_type = utils::get_scalar_dtype(alpha);
@@ -81,6 +83,8 @@ Tensor& add_scalar_out(
       out,
       "Failed to resize output tensor.");
 
+  ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
+
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType alpha_type = utils::get_scalar_dtype(alpha);
diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp
index bb0be9a4c1b..bd0b6e68445 100644
--- a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp
+++ b/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp
@@ -23,6 +23,8 @@ Tensor& unary_ufunc_realhb_to_floath(
     Tensor& out) {
   (void)ctx;
 
+  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
+
   // Resize for dynamic shape
   ET_KERNEL_CHECK_MSG(
       ctx,
diff --git a/kernels/portable/cpu/util/activation_ops_util.cpp b/kernels/portable/cpu/util/activation_ops_util.cpp
index b697c49e04f..273f5d59595 100644
--- a/kernels/portable/cpu/util/activation_ops_util.cpp
+++ b/kernels/portable/cpu/util/activation_ops_util.cpp
@@ -15,6 +15,7 @@ namespace executor {
 
 bool check_gelu_args(const Tensor& in, string_view approximate, Tensor& out) {
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_IF_FALSE(in.scalar_type() != ScalarType::Bool);
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
       approximate == "tanh" || approximate == "none",
       "Invalid approximation format: %.*s for gelu",
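Reviewer note: the `ET_KERNEL_CHECK` / `ET_KERNEL_CHECK_MSG` calls introduced above replace the aborting `ET_CHECK*` macros. On a failed condition the kernel records the error on its runtime context and returns the given value instead of killing the process. A minimal sketch of the control flow these macros imply (illustrative only; the real macros live under //executorch/runtime/kernel and differ in detail):

```cpp
// Hypothetical expansion sketch of ET_KERNEL_CHECK(ctx, cond, error, retval).
// KernelRuntimeContext::fail() and failure_state() are the real APIs this
// change builds on; the macro body here is a simplification.
#define SKETCH_KERNEL_CHECK(ctx, cond, error, retval) \
  do {                                                \
    if (!(cond)) {                                    \
      (ctx).fail(torch::executor::Error::error);      \
      return retval;                                  \
    }                                                 \
  } while (false)
```

With this pattern a bad argument turns the op into a no-op whose failure state the caller can query, which is exactly what the reworked tests later in this diff assert on.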
diff --git a/kernels/portable/cpu/util/broadcast_util.cpp b/kernels/portable/cpu/util/broadcast_util.cpp
index 4173c1b0856..64ef2086ffd 100644
--- a/kernels/portable/cpu/util/broadcast_util.cpp
+++ b/kernels/portable/cpu/util/broadcast_util.cpp
@@ -198,7 +198,10 @@ Tensor broadcast_tensor(
       repeats[i] = 1;
     }
   }
-  repeat_tensor(broadcast_from, makeArrayRef(repeats, ndim), out);
+
+  ET_CHECK(
+      repeat_tensor(broadcast_from, makeArrayRef(repeats, ndim), out) ==
+      Error::Ok);
 
   free(repeats);
diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h
index 6ca1cf7ee97..77f42c266ad 100644
--- a/kernels/portable/cpu/util/broadcast_util.h
+++ b/kernels/portable/cpu/util/broadcast_util.h
@@ -97,7 +97,7 @@ __ET_DEPRECATED exec_aten::Tensor broadcast_tensor(
  * @param[out] out_dim The dimension of the broadcasted target
  * tensor
  */
-[[nodiscard]] Error get_broadcast_target_size(
+__ET_NODISCARD Error get_broadcast_target_size(
    const exec_aten::ArrayRef a_size,
    const exec_aten::ArrayRef b_size,
    Tensor::SizesType* out_sizes,
@@ -115,7 +115,7 @@ __ET_DEPRECATED exec_aten::Tensor broadcast_tensor(
  * @param[out] out_dim The dimension of the broadcasted target
  * tensor
  */
-[[nodiscard]] Error get_broadcast_target_size(
+__ET_NODISCARD Error get_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    Tensor::SizesType* out_sizes,
@@ -130,7 +130,7 @@ __ET_DEPRECATED exec_aten::Tensor broadcast_tensor(
 * @param[in] b The second tensor going to be broadcasted.
 * @param[out] out The output tensor that will be resized.
 */
-[[nodiscard]] inline Error
+__ET_NODISCARD inline Error
 resize_to_broadcast_target_size(const Tensor& a, const Tensor& b, Tensor& out) {
   Tensor::SizesType expected_output_size[kTensorDimensionLimit];
   size_t expected_output_dim = 0;
@@ -156,7 +156,7 @@ resize_to_broadcast_target_size(const Tensor& a, const Tensor& b, Tensor& out) {
 * @param[in] c The third tensor going to be broadcasted.
 * @param[out] out The output tensor that will be resized.
 */
-[[nodiscard]] inline Error resize_to_broadcast_target_size(
+__ET_NODISCARD inline Error resize_to_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
diff --git a/kernels/portable/cpu/util/copy_ops_util.cpp b/kernels/portable/cpu/util/copy_ops_util.cpp
index ae48dee0fb4..5b54cd6890d 100644
--- a/kernels/portable/cpu/util/copy_ops_util.cpp
+++ b/kernels/portable/cpu/util/copy_ops_util.cpp
@@ -114,6 +114,7 @@ bool check_cat_args(
   // Ensure dim is in range.
   ET_LOG_AND_RETURN_IF_FALSE(
       tensors[ref_i].numel() == 0 || tensors[ref_i].dim() > dim);
+  ET_LOG_AND_RETURN_IF_FALSE(dim >= 0);
 
   return true;
 }
@@ -378,6 +379,7 @@ bool check_slice_copy_args(
     int64_t dim,
     int64_t step,
     Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0);
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
@@ -737,6 +739,8 @@ bool check_unsqueeze_copy_args(
     const Tensor input,
     int64_t dim,
     const Tensor out) {
+  ET_LOG_AND_RETURN_IF_FALSE(dim >= 0);
+
   // The input and out shall share same dtype
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, out));
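Reviewer note: `__ET_NODISCARD` is the project's portable stand-in for the C++17 `[[nodiscard]]` attribute, so the header keeps building on toolchains pinned to older standards while still warning when a returned `Error` is dropped. A hedged sketch of the call-site pattern the annotation enforces (the wrapper function name is hypothetical):

```cpp
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>

using torch::executor::Error;

// Callers must consume the returned Error; ignoring it now produces an
// unused-result warning instead of failing silently.
Error resize_for_binary_op(
    const exec_aten::Tensor& a,
    const exec_aten::Tensor& b,
    exec_aten::Tensor& out) {
  Error err = torch::executor::resize_to_broadcast_target_size(a, b, out);
  if (err != Error::Ok) {
    return err;  // propagate rather than abort
  }
  return Error::Ok;
}
```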
diff --git a/kernels/portable/cpu/util/kernel_ops_util.cpp b/kernels/portable/cpu/util/kernel_ops_util.cpp
index 384b1859b22..fdbc5a0e532 100644
--- a/kernels/portable/cpu/util/kernel_ops_util.cpp
+++ b/kernels/portable/cpu/util/kernel_ops_util.cpp
@@ -462,6 +462,8 @@ bool check_slice_scatter_args(
     int64_t num_values,
     int64_t step,
     Tensor output) {
+  ET_LOG_AND_RETURN_IF_FALSE(input.dim() > 0);
+
   // Check dim. The dim planned to be selected on must exist in input.
   ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, input.dim()));
diff --git a/kernels/portable/cpu/util/repeat_util.cpp b/kernels/portable/cpu/util/repeat_util.cpp
index bc721cd493c..9acb7ba088e 100644
--- a/kernels/portable/cpu/util/repeat_util.cpp
+++ b/kernels/portable/cpu/util/repeat_util.cpp
@@ -20,12 +20,12 @@ using Tensor = exec_aten::Tensor;
 
 namespace {
 
-void check_repeat_args(
+bool check_repeat_args(
     Tensor self,
     exec_aten::ArrayRef repeats,
     Tensor& out) {
   // Ensure the self tensors list is non-empty.
-  ET_CHECK_MSG(
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
       repeats.size() >= self.dim(),
       "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
@@ -34,11 +34,11 @@ void check_repeat_args(
   for (auto repeat : repeats) {
     all_non_negative = all_non_negative && (repeat >= 0);
   }
-  ET_CHECK_MSG(
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
       all_non_negative, "Trying to create tensor with negative dimension");
 
   /// Check if out.size() is legal.
-  ET_CHECK_MSG(
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
       out.dim() == repeats.size(),
       "The dimension of out shall equal size of repeats, but now is %zd and %zd",
       out.dim(),
@@ -47,12 +47,12 @@ void check_repeat_args(
   // Right now we only support the tensors whose dimension is no greater than
   // kTensorDimensionLimit. Only check out tensor because the number of
   // dimension of out tensor shall have more than or equal to self tensor
-  ET_CHECK_MSG(
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
       out.dim() <= kTensorDimensionLimit,
       "The dimension of input and output should not be larger than %zd",
       kTensorDimensionLimit);
 
-  ET_CHECK_SAME_DTYPE2(out, self);
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(out, self));
 
   // We pad one to the beginning of self.size() to make its length equal
   // repeats, and called it reformat_self_size. We then make point-to-point mul
@@ -66,13 +66,15 @@ void check_repeat_args(
     reformat_self_size[out.dim() - 1 - i] = self.size(self.dim() - 1 - i);
   }
   for (size_t i = 0; i < repeats.size(); i++) {
-    ET_CHECK_MSG(
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        reformat_self_size[i] * repeats[i] == out.size(i),
        "Expect out size at dimension %zu is %" PRId64 ", but now is %zd",
        i,
        reformat_self_size[i] * repeats[i],
        out.size(i));
   }
+
+  return true;
 }
 
 // Given the indices to a point in an n-D tensor, and the stride (in bytes)
@@ -163,16 +165,19 @@ void repeat_internal(
 
 // TODO(gasoonjia): dynamic allocate array to support tensor dimension larger
 // than kTensorDimensionLimit.
-Tensor& repeat_tensor(
+Error repeat_tensor(
     const Tensor& self,
     exec_aten::ArrayRef repeats,
     Tensor& out) {
-  // Assert that the args are valid.
-  check_repeat_args(self, repeats, out);
+  // Verify that the args are valid.
+  ET_CHECK_OR_RETURN_ERROR(
+      check_repeat_args(self, repeats, out),
+      InvalidArgument,
+      "Repeat arguments are invalid.");
 
   // Returns out if out.numel == 0, nothing needs to be repeated.
   if (out.numel() == 0) {
-    return out;
+    return Error::Ok;
   }
 
   ssize_t element_size = out.element_size();
@@ -183,7 +188,7 @@ Tensor& repeat_tensor(
     const char* src = self.const_data_ptr();
     char* dest = out.mutable_data_ptr();
     memcpy(dest, src, element_size);
-    return out;
+    return Error::Ok;
   }
 
   // Treats zero-dim self as one-dim tensor with size {1}.
@@ -274,7 +279,7 @@ Tensor& repeat_tensor(
     accum_offset *= out.size(i);
   }
 
-  return out;
+  return Error::Ok;
 }
 
 } // namespace executor
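Reviewer note: `repeat_tensor` now reports failure through `Error` instead of asserting inside `check_repeat_args`, so kernels that call it can surface the problem through their runtime context. A hedged sketch of a kernel-side caller (the wrapper name is illustrative; the real op plumbing may differ):

```cpp
// Sketch: converting repeat_tensor's Error into a context-based kernel
// failure, mirroring the ET_KERNEL_CHECK pattern used elsewhere in this PR.
Tensor& repeat_out_sketch(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    exec_aten::ArrayRef<int64_t> repeats,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      torch::executor::repeat_tensor(self, repeats, out) == Error::Ok,
      InvalidArgument,
      out);
  return out;
}
```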
diff --git a/kernels/portable/cpu/util/repeat_util.h b/kernels/portable/cpu/util/repeat_util.h
index 68e72c8aa83..28f5cfa5556 100644
--- a/kernels/portable/cpu/util/repeat_util.h
+++ b/kernels/portable/cpu/util/repeat_util.h
@@ -20,9 +20,9 @@ namespace executor {
 * @param[in] The number of times to repeat this tensor along each dimension
 * @param[in] Output tensor to write to.
 *
- * @returns Repeated tensor.
+ * @returns The status of the repeat operation.
 */
-exec_aten::Tensor& repeat_tensor(
+Error repeat_tensor(
    const exec_aten::Tensor& in,
    exec_aten::ArrayRef repeats,
    exec_aten::Tensor& out);
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index 135b8af5af8..f7ca5bce920 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -27,6 +27,7 @@ def define_common_targets():
        ],
        exported_headers = ["repeat_util.h"],
        deps = [
+           "//executorch/runtime/kernel:kernel_includes",
            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
            "//executorch/runtime/core/exec_aten/util:tensor_util",
        ],
diff --git a/kernels/portable/test/op_allclose_test.cpp b/kernels/portable/test/op_allclose_test.cpp
index e3f9b494410..25dbebce2a8 100644
--- a/kernels/portable/test/op_allclose_test.cpp
+++ b/kernels/portable/test/op_allclose_test.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -188,14 +189,16 @@ TEST(OpAllCloseTest, MismatchedInputShapesDeath) {
   TensorFactory tf_bool;
   Tensor out = tf_bool.zeros(/*sizes=*/{1});
 
-  ET_EXPECT_KERNEL_FAILURE(allclose_out(
-      a,
-      b,
-      default_rtol,
-      default_atol,
-      /*equal_nan=*/false,
-      /*dummy_param=*/false,
-      out));
+  ET_EXPECT_DEATH(
+      allclose_out(
+          a,
+          b,
+          default_rtol,
+          default_atol,
+          /*equal_nan=*/false,
+          /*dummy_param=*/false,
+          out),
+      "");
 }
 
 TEST(OpAllCloseTest, MismatchedInputDtypesDeath) {
@@ -208,14 +211,16 @@ TEST(OpAllCloseTest, MismatchedInputDtypesDeath) {
   TensorFactory tf_bool;
   Tensor out = tf_bool.zeros(/*sizes=*/{1});
 
-  ET_EXPECT_KERNEL_FAILURE(allclose_out(
-      a,
-      b,
-      default_rtol,
-      default_atol,
-      /*equal_nan=*/false,
-      /*dummy_param=*/false,
-      out));
+  ET_EXPECT_DEATH(
+      allclose_out(
+          a,
+          b,
+          default_rtol,
+          default_atol,
+          /*equal_nan=*/false,
+          /*dummy_param=*/false,
+          out),
+      "");
 }
 
 TEST(OpAllCloseTest, IncorrectOutputDtypeDeath) {
@@ -224,14 +229,16 @@ TEST(OpAllCloseTest, IncorrectOutputDtypeDeath) {
   Tensor b = tf_float.ones(/*sizes=*/{2, 2});
   Tensor out = tf_float.zeros(/*sizes=*/{1});
 
-  ET_EXPECT_KERNEL_FAILURE(allclose_out(
-      a,
-      b,
-      default_rtol,
-      default_atol,
-      /*equal_nan=*/false,
-      /*dummy_param=*/false,
-      out));
+  ET_EXPECT_DEATH(
+      allclose_out(
+          a,
+          b,
+          default_rtol,
+          default_atol,
+          /*equal_nan=*/false,
+          /*dummy_param=*/false,
+          out),
+      "");
 }
 
 TEST(OpAllCloseTest, IncorrectOutputShapeDeath) {
@@ -241,14 +248,16 @@ TEST(OpAllCloseTest, IncorrectOutputShapeDeath) {
   TensorFactory tf_bool;
   Tensor out = tf_bool.zeros(/*sizes=*/{2, 2});
 
-  ET_EXPECT_KERNEL_FAILURE(allclose_out(
-      a,
-      b,
-      default_rtol,
-      default_atol,
-      /*equal_nan=*/false,
-      /*dummy_param=*/false,
-      out));
+  ET_EXPECT_DEATH(
+      allclose_out(
+          a,
+          b,
+          default_rtol,
+          default_atol,
+          /*equal_nan=*/false,
+          /*dummy_param=*/false,
+          out),
+      "");
 }
 
 TEST(OpAllCloseTest, FloatTensorsVaryWithinRelativeTolerance) {
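Reviewer note: `allclose.out` is not migrated to context-based failures in this change, so its tests assert on process death directly rather than through the retooled `ET_EXPECT_KERNEL_FAILURE`. `ET_EXPECT_DEATH` behaves like gtest's `EXPECT_DEATH`: the statement must terminate the process, and the empty-string matcher accepts any stderr output. A minimal hedged usage sketch:

```cpp
// Sketch: a death test for an aborting ET_CHECK; any abort message
// matches the "" regex.
TEST(DeathTestSketch, AbortingCheckDies) {
  ET_EXPECT_DEATH(ET_CHECK_MSG(false, "unrecoverable"), "");
}
```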
diff --git a/kernels/test/TestUtil.h b/kernels/test/TestUtil.h
index a8ebc21c0f9..ed72dbc4128 100644
--- a/kernels/test/TestUtil.h
+++ b/kernels/test/TestUtil.h
@@ -13,6 +13,9 @@
 
 #pragma once
 
+#include
+#include
+#include
 #include
 #include
 
@@ -21,16 +24,62 @@
 * Ensure the kernel will fail when `_statement` is executed.
 * @param _statement Statement to execute.
 */
-#define ET_EXPECT_KERNEL_FAILURE(_statement) EXPECT_ANY_THROW(_statement)
+#define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
+  EXPECT_ANY_THROW(_statement)
 
-#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_statement, _matcher) \
+#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \
   EXPECT_ANY_THROW(_statement)
 
 #else
 
-#define ET_EXPECT_KERNEL_FAILURE(_statement) ET_EXPECT_DEATH(_statement, "")
+#define ET_EXPECT_KERNEL_FAILURE(_context, _statement)               \
+  do {                                                               \
+    _statement;                                                      \
+    expect_failure();                                                \
+    if ((_context).failure_state() == torch::executor::Error::Ok) {  \
+      ET_LOG(Error, "Expected kernel failure but found success.");   \
+      ADD_FAILURE();                                                 \
+    }                                                                \
+  } while (false)
 
-#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_statement, _matcher) \
-  ET_EXPECT_DEATH(_statement, _matcher)
+#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \
+  do {                                                                \
+    _statement;                                                       \
+    expect_failure();                                                 \
+    if ((_context).failure_state() == torch::executor::Error::Ok) {   \
+      ET_LOG(Error, "Expected kernel failure but found success.");    \
+      ADD_FAILURE();                                                  \
+    }                                                                 \
+  } while (false)
 
 #endif // USE_ATEN_LIB
+
+/*
+ * Common test fixture for kernel / operator-level tests. Provides
+ * a runtime context object and verifies failure state post-execution.
+ */
+class OperatorTest : public ::testing::Test {
+ public:
+  OperatorTest() : expect_failure_(false) {}
+
+  void SetUp() override {
+    torch::executor::runtime_init();
+  }
+
+  void TearDown() override {
+    // Validate error state.
+    if (!expect_failure_) {
+      EXPECT_EQ(context_.failure_state(), torch::executor::Error::Ok);
+    } else {
+      EXPECT_NE(context_.failure_state(), torch::executor::Error::Ok);
+    }
+  }
+
+  void expect_failure() {
+    expect_failure_ = true;
+  }
+
+ protected:
+  exec_aten::RuntimeContext context_;
+  bool expect_failure_;
+};
diff --git a/kernels/test/op_abs_test.cpp b/kernels/test/op_abs_test.cpp
index e06911e8e88..b54cd971567 100644
--- a/kernels/test/op_abs_test.cpp
+++ b/kernels/test/op_abs_test.cpp
@@ -19,12 +19,14 @@ using exec_aten::ScalarType;
 using exec_aten::Tensor;
 using torch::executor::testing::TensorFactory;
 
-Tensor& op_abs_out(const Tensor& self, Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::abs_outf(context, self, out);
-}
-
-TEST(OpAbsTest, SanityCheck) {
+class OpAbsTest : public OperatorTest {
+ protected:
+  Tensor& op_abs_out(const Tensor& self, Tensor& out) {
+    return torch::executor::aten::abs_outf(context_, self, out);
+  }
+};
+
+TEST_F(OpAbsTest, SanityCheck) {
   TensorFactory tf;
 
   Tensor in = tf.make({1, 7}, {-3.0, -2.5, -1.01, 0.0, 1.01, 2.5, 3.0});
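Reviewer note: the new `OperatorTest` fixture is the backbone of the test migration that follows. Every `TEST_F` shares `context_`, and `TearDown` cross-checks the context's failure state against whether the test declared an expected failure. A hedged sketch of porting one more op test onto the fixture (`foo_outf` is a placeholder entry point, not a real op):

```cpp
class OpFooOutTest : public OperatorTest {
 protected:
  // All calls route through the fixture's context_ so that TearDown()
  // can verify context_.failure_state() after the test body runs.
  Tensor& op_foo_out(const Tensor& self, Tensor& out) {
    return torch::executor::aten::foo_outf(context_, self, out);
  }
};

TEST_F(OpFooOutTest, MismatchedShapesFails) {
  TensorFactory<ScalarType::Float> tf;
  Tensor in = tf.ones({4});
  Tensor out = tf.ones({2, 2});
  // Marks the test as expecting failure and checks failure_state() != Ok.
  ET_EXPECT_KERNEL_FAILURE(context_, op_foo_out(in, out));
}
```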
diff --git a/kernels/test/op_acos_test.cpp b/kernels/test/op_acos_test.cpp
index 78bf257f8b7..9c9c9211be0 100644
--- a/kernels/test/op_acos_test.cpp
+++ b/kernels/test/op_acos_test.cpp
@@ -21,12 +21,49 @@ using exec_aten::Tensor;
 using exec_aten::TensorShapeDynamism;
 using torch::executor::testing::TensorFactory;
 
-Tensor& op_acos_out(const Tensor& self, Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::acos_outf(context, self, out);
-}
+class OpAcosOutTest : public OperatorTest {
+ protected:
+  Tensor& op_acos_out(const Tensor& self, Tensor& out) {
+    return torch::executor::aten::acos_outf(context_, self, out);
+  }
+
+  // Common testing for acos operator and all kinds of supported input types
+  template
+  void test_floating_point_acos_out(
+      const std::vector& out_shape = {1, 6},
+      TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
+    TensorFactory tf_in;
+    TensorFactory tf_out;
+
+    // Destination for the acos operator.
+    Tensor out = tf_out.zeros(out_shape, dynamism);
+
+    // clang-format off
+    op_acos_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out);
+
+    // Check that it matches (or close to) the expected output.
+    EXPECT_TENSOR_CLOSE(
+        out,
+        tf_out.make({1, 6}, { 1.570796, 0.000000, NAN, NAN, NAN, NAN }));
+    // clang-format on
+  }
+
+  // Unhandled output dtypes.
+  template
+  void test_acos_invalid_output_dtype_dies() {
+    TensorFactory tf;
+    TensorFactory tf_out;
+
+    const std::vector sizes = {2, 5};
+
+    Tensor in = tf.ones(sizes);
+    Tensor out = tf_out.zeros(sizes);
+
+    ET_EXPECT_KERNEL_FAILURE(context_, op_acos_out(in, out));
+  }
+};
 
-TEST(OpAcosOutKernelTest, HandleBoolInput) {
+TEST_F(OpAcosOutTest, HandleBoolInput) {
   TensorFactory tf_bool;
   TensorFactory tf_float;
 
@@ -39,28 +76,7 @@ TEST(OpAcosOutKernelTest, HandleBoolInput) {
   EXPECT_TENSOR_CLOSE(op_acos_out(a, out), res);
 }
 
-// Common testing for acos operator and all kinds of supported input types
-template
-void test_floating_point_acos_out(
-    const std::vector& out_shape = {1, 6},
-    TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
-  TensorFactory tf_in;
-  TensorFactory tf_out;
-
-  // Destination for the acos operator.
-  Tensor out = tf_out.zeros(out_shape, dynamism);
-
-  // clang-format off
-  op_acos_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out);
-
-  // Check that it matches (or close to) the expected output.
-  EXPECT_TENSOR_CLOSE(
-      out,
-      tf_out.make({1, 6}, { 1.570796, 0.000000, NAN, NAN, NAN, NAN }));
-  // clang-format on
-}
-
-TEST(OpAcosOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputHalfOutputStaticDynamismSupport) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
   }
@@ -70,21 +86,21 @@ TEST(OpAcosOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcosOutKernelTest, AllRealInputFloatOutputStaticDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputFloatOutputStaticDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acos_out();
   ET_FORALL_REAL_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
-TEST(OpAcosOutKernelTest, AllRealInputDoubleOutputStaticDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputDoubleOutputStaticDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acos_out();
   ET_FORALL_REAL_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
-TEST(OpAcosOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputHalfOutputBoundDynamismSupport) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
   }
@@ -95,7 +111,7 @@ TEST(OpAcosOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcosOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputFloatOutputBoundDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acos_out( \
       {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND);
@@ -103,7 +119,7 @@ TEST(OpAcosOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcosOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputDoubleOutputBoundDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acos_out( \
       {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND);
@@ -111,7 +127,7 @@ TEST(OpAcosOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcosOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputFloatOutputUnboundDynamismSupport) {
   if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Dynamic shape unbound not supported";
   }
@@ -122,7 +138,7 @@ TEST(OpAcosOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcosOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
+TEST_F(OpAcosOutTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
   if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Dynamic shape unbound not supported";
   }
@@ -133,21 +149,7 @@ TEST(OpAcosOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-// Unhandled output dtypes.
-template
-void test_acos_invalid_output_dtype_dies() {
-  TensorFactory tf;
-  TensorFactory tf_out;
-
-  const std::vector sizes = {2, 5};
-
-  Tensor in = tf.ones(sizes);
-  Tensor out = tf_out.zeros(sizes);
-
-  ET_EXPECT_KERNEL_FAILURE(op_acos_out(in, out));
-}
-
-TEST(OpAcosOutKernelTest, AllNonFloatOutputDTypeDies) {
+TEST_F(OpAcosOutTest, AllNonFloatOutputDTypeDies) {
 #define TEST_ENTRY(ctype, dtype) \
   test_acos_invalid_output_dtype_dies();
   ET_FORALL_INT_TYPES(TEST_ENTRY);
@@ -155,7 +157,7 @@ TEST(OpAcosOutKernelTest, AllNonFloatOutputDTypeDies) {
 }
 
 // Mismatched shape tests.
-TEST(OpAcosOutKernelTest, MismatchedInputShapesDies) {
+TEST_F(OpAcosOutTest, MismatchedInputShapesDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
   }
@@ -164,5 +166,5 @@ TEST(OpAcosOutKernelTest, MismatchedInputShapesDies) {
   Tensor a = tf.ones(/*sizes=*/{4});
   Tensor out = tf.ones(/*sizes=*/{2, 2});
 
-  ET_EXPECT_KERNEL_FAILURE(op_acos_out(a, out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_acos_out(a, out));
 }
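Reviewer note: the dtype sweeps above rely on X-macros. `ET_FORALL_REAL_TYPES(TEST_ENTRY)` expands `TEST_ENTRY(ctype, dtype)` once per real scalar type, instantiating the templated helper each time. A simplified sketch of the mechanism; the template argument list on the helper call is an assumption, since the original angle-bracketed text did not survive extraction:

```cpp
// Simplified stand-in for ET_FORALL_REAL_TYPES; the real macro covers the
// full set of real dtypes, not just these three.
#define SKETCH_FORALL_REAL_TYPES(_) \
  _(float, Float)                   \
  _(double, Double)                 \
  _(int32_t, Int)

#define TEST_ENTRY(ctype, dtype) \
  test_floating_point_acos_out<ScalarType::dtype, ScalarType::Float>();
SKETCH_FORALL_REAL_TYPES(TEST_ENTRY)
#undef TEST_ENTRY
```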
diff --git a/kernels/test/op_acosh_test.cpp b/kernels/test/op_acosh_test.cpp
index 69d3c8c790c..ce01411fd3f 100644
--- a/kernels/test/op_acosh_test.cpp
+++ b/kernels/test/op_acosh_test.cpp
@@ -21,12 +21,49 @@ using exec_aten::Tensor;
 using exec_aten::TensorShapeDynamism;
 using torch::executor::testing::TensorFactory;
 
-Tensor& op_acosh_out(const Tensor& self, Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::acosh_outf(context, self, out);
-}
+class OpAcoshOutTest : public OperatorTest {
+ protected:
+  Tensor& op_acosh_out(const Tensor& self, Tensor& out) {
+    return torch::executor::aten::acosh_outf(context_, self, out);
+  }
+
+  // Common testing for acosh operator and all kinds of supported input types
+  template
+  void test_floating_point_acosh_out(
+      const std::vector& out_shape = {1, 6},
+      TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
+    TensorFactory tf_in;
+    TensorFactory tf_out;
+
+    // Destination for the acosh operator.
+    Tensor out = tf_out.zeros(out_shape, dynamism);
+
+    // clang-format off
+    op_acosh_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out);
+
+    // Check that it matches (or close to) the expected output.
+    EXPECT_TENSOR_CLOSE(
+        out,
+        tf_out.make({1, 6}, { NAN, 0.000000, 1.762747, 2.292432, 2.993223, 5.298292 }));
+    // clang-format on
+  }
+
+  // Unhandled output dtypes.
+  template
+  void test_acosh_invalid_output_dtype_dies() {
+    TensorFactory tf;
+    TensorFactory tf_out;
+
+    const std::vector sizes = {2, 5};
+
+    Tensor in = tf.ones(sizes);
+    Tensor out = tf_out.zeros(sizes);
+
+    ET_EXPECT_KERNEL_FAILURE(context_, op_acosh_out(in, out));
+  }
+};
 
-TEST(OpAcoshOutKernelTest, HandleBoolInput) {
+TEST_F(OpAcoshOutTest, HandleBoolInput) {
   TensorFactory tf_bool;
   TensorFactory tf_float;
 
@@ -39,42 +76,21 @@ TEST(OpAcoshOutKernelTest, HandleBoolInput) {
   EXPECT_TENSOR_CLOSE(op_acosh_out(a, out), res);
 }
 
-// Common testing for acosh operator and all kinds of supported input types
-template
-void test_floating_point_acosh_out(
-    const std::vector& out_shape = {1, 6},
-    TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
-  TensorFactory tf_in;
-  TensorFactory tf_out;
-
-  // Destination for the acosh operator.
-  Tensor out = tf_out.zeros(out_shape, dynamism);
-
-  // clang-format off
-  op_acosh_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out);
-
-  // Check that it matches (or close to) the expected output.
-  EXPECT_TENSOR_CLOSE(
-      out,
-      tf_out.make({1, 6}, { NAN, 0.000000, 1.762747, 2.292432, 2.993223, 5.298292 }));
-  // clang-format on
-}
-
-TEST(OpAcoshOutKernelTest, AllRealInputFloatOutputStaticDynamismSupport) {
+TEST_F(OpAcoshOutTest, AllRealInputFloatOutputStaticDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acosh_out();
   ET_FORALL_REAL_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
-TEST(OpAcoshOutKernelTest, AllRealInputDoubleOutputStaticDynamismSupport) {
+TEST_F(OpAcoshOutTest, AllRealInputDoubleOutputStaticDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acosh_out();
   ET_FORALL_REAL_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
-TEST(OpAcoshOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) {
+TEST_F(OpAcoshOutTest, AllRealInputFloatOutputBoundDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acosh_out( \
       {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND);
@@ -82,7 +98,7 @@ TEST(OpAcoshOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcoshOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) {
+TEST_F(OpAcoshOutTest, AllRealInputDoubleOutputBoundDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_acosh_out( \
       {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND);
@@ -90,7 +106,7 @@ TEST(OpAcoshOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcoshOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) {
+TEST_F(OpAcoshOutTest, AllRealInputFloatOutputUnboundDynamismSupport) {
   if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Dynamic shape unbound not supported";
   }
@@ -101,7 +117,7 @@ TEST(OpAcoshOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAcoshOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
+TEST_F(OpAcoshOutTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
   if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Dynamic shape unbound not supported";
   }
@@ -112,21 +128,7 @@ TEST(OpAcoshOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
 #undef TEST_ENTRY
 }
 
-// Unhandled output dtypes.
-template
-void test_acosh_invalid_output_dtype_dies() {
-  TensorFactory tf;
-  TensorFactory tf_out;
-
-  const std::vector sizes = {2, 5};
-
-  Tensor in = tf.ones(sizes);
-  Tensor out = tf_out.zeros(sizes);
-
-  ET_EXPECT_KERNEL_FAILURE(op_acosh_out(in, out));
-}
-
-TEST(OpAcoshOutKernelTest, AllNonFloatOutputDTypeDies) {
+TEST_F(OpAcoshOutTest, AllNonFloatOutputDTypeDies) {
 #define TEST_ENTRY(ctype, dtype) \
   test_acosh_invalid_output_dtype_dies();
   ET_FORALL_INT_TYPES(TEST_ENTRY);
@@ -134,7 +136,7 @@ TEST(OpAcoshOutKernelTest, AllNonFloatOutputDTypeDies) {
 }
 
 // Mismatched shape tests.
-TEST(OpAcoshOutKernelTest, MismatchedInputShapesDies) {
+TEST_F(OpAcoshOutTest, MismatchedInputShapesDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
   }
@@ -144,5 +146,5 @@ TEST(OpAcoshOutKernelTest, MismatchedInputShapesDies) {
   Tensor a = tf.ones(/*sizes=*/{4});
   Tensor out = tf.ones(/*sizes=*/{2, 2});
 
-  ET_EXPECT_KERNEL_FAILURE(op_acosh_out(a, out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_acosh_out(a, out));
 }
diff --git a/kernels/test/op_add_test.cpp b/kernels/test/op_add_test.cpp
index 834205eff65..50e11002c70 100644
--- a/kernels/test/op_add_test.cpp
+++ b/kernels/test/op_add_test.cpp
@@ -24,115 +24,119 @@ using exec_aten::Tensor;
 using torch::executor::testing::SupportedFeatures;
 using torch::executor::testing::TensorFactory;
 
-Tensor& op_add_out(
-    const Tensor& self,
-    const Tensor& other,
-    const Scalar& alpha,
-    Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::add_outf(context, self, other, alpha, out);
-}
-
-Tensor& op_add_scalar_out(
-    const Tensor& self,
-    const Scalar& other,
-    const Scalar& alpha,
-    Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::add_outf(context, self, other, alpha, out);
-}
+class OpAddOutKernelTest : public OperatorTest {
+ protected:
+  Tensor& op_add_out(
+      const Tensor& self,
+      const Tensor& other,
+      const Scalar& alpha,
+      Tensor& out) {
+    return torch::executor::aten::add_outf(context_, self, other, alpha, out);
+  }
 
-template
-void test_add() {
-  TensorFactory tf_a;
-  TensorFactory tf_b;
-  TensorFactory tf_out;
+  template
+  void test_add() {
+    TensorFactory tf_a;
+    TensorFactory tf_b;
+    TensorFactory tf_out;
 
-  const std::vector sizes = {2, 2};
+    const std::vector sizes = {2, 2};
 
-  // Destination for the sum.
-  Tensor out = tf_out.zeros(sizes);
+    // Destination for the sum.
+    Tensor out = tf_out.zeros(sizes);
 
-  // Add two tensors.
-  op_add_out(
-      tf_a.make(sizes, /*data=*/{1, 2, 4, 8}),
-      tf_b.ones(sizes),
-      /*alpha=*/1,
-      out);
+    // Add two tensors.
+    op_add_out(
+        tf_a.make(sizes, /*data=*/{1, 2, 4, 8}),
+        tf_b.ones(sizes),
+        /*alpha=*/1,
+        out);
 
-  // Check that it matches the expected output.
-  EXPECT_TENSOR_EQ(out, tf_out.make(sizes, /*data=*/{2, 3, 5, 9}));
-}
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_EQ(out, tf_out.make(sizes, /*data=*/{2, 3, 5, 9}));
+  }
 
-template
-void test_add_enumerate_out_types() {
-  test_add();
-  test_add();
-  test_add();
-  // Integral out type is only allowed if both inputs are integral types
-  if (isIntegralType(DTYPE_A, false) && isIntegralType(DTYPE_B, false)) {
-    test_add();
-    test_add();
+  template
+  void test_add_enumerate_out_types() {
+    test_add();
+    test_add();
+    test_add();
+    // Integral out type is only allowed if both inputs are integral types
+    if (isIntegralType(DTYPE_A, false) && isIntegralType(DTYPE_B, false)) {
+      test_add();
+      test_add();
+    }
   }
-}
 
-template
-void test_add_enumerate_b_types() {
+  template
+  void test_add_enumerate_b_types() {
 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
   test_add_enumerate_out_types();
 
-  ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
+    ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
 
#undef ENUMERATE_TEST_ENTRY
-}
+  }
 
-void test_add_enumerate_a_types() {
+  void test_add_enumerate_a_types() {
 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
   test_add_enumerate_b_types();
 
-  ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
+    ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
 
#undef ENUMERATE_TEST_ENTRY
-}
+  }
 
-/**
- * Uses the function templates above to test all valid combinations of inputs
- * and output dtypes
- */
-TEST(OpAddOutKernelTest, AllRealDtypesSupported) {
-  test_add_enumerate_a_types();
-}
+  // Common testing for adding two floating point Tensors.
+  template
+  void test_floating_point_add_out() {
+    TensorFactory tf;
 
-// Common testing for adding two floating point Tensors.
-template
-void test_floating_point_add_out() {
-  TensorFactory tf;
+    const std::vector sizes = {2, 2};
 
-  const std::vector sizes = {2, 2};
+    // Destination for the sum.
+    Tensor out = tf.zeros(sizes);
 
-  // Destination for the sum.
-  Tensor out = tf.zeros(sizes);
+    // Add two tensors.
+    op_add_out(
+        tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
+        tf.ones(sizes),
+        /*alpha=*/1.1,
+        out);
 
-  // Add two tensors.
-  op_add_out(
-      tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
-      tf.ones(sizes),
-      /*alpha=*/1.1,
-      out);
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.2, 3.3, 5.5, 9.9}));
+  }
+};
+
+class OpAddScalarOutKernelTest : public OperatorTest {
+ protected:
+  Tensor& op_add_scalar_out(
+      const Tensor& self,
+      const Scalar& other,
+      const Scalar& alpha,
+      Tensor& out) {
+    return torch::executor::aten::add_outf(context_, self, other, alpha, out);
+  }
+};
 
-  // Check that it matches the expected output.
-  EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.2, 3.3, 5.5, 9.9}));
+/**
+ * Uses the function templates above to test all valid combinations of inputs
+ * and output dtypes
+ */
+TEST_F(OpAddOutKernelTest, AllRealDtypesSupported) {
+  test_add_enumerate_a_types();
 }
 
-TEST(OpAddOutKernelTest, FloatTensors) {
+TEST_F(OpAddOutKernelTest, FloatTensors) {
   test_floating_point_add_out();
 }
 
-TEST(OpAddOutKernelTest, DoubleTensors) {
+TEST_F(OpAddOutKernelTest, DoubleTensors) {
   test_floating_point_add_out();
 }
 
-TEST(OpAddOutKernelTest, BoolAndIntInputTensor) {
+TEST_F(OpAddOutKernelTest, BoolAndIntInputTensor) {
   TensorFactory tf;
   TensorFactory tfi;
 
@@ -147,7 +151,7 @@ TEST(OpAddOutKernelTest, BoolAndIntInputTensor) {
   EXPECT_TENSOR_EQ(out, tfi.make(sizes, {2, 5, 3, 4}));
 }
 
-TEST(OpAddOutKernelTest, BoolAndBoolInputTensor) {
+TEST_F(OpAddOutKernelTest, BoolAndBoolInputTensor) {
   et_pal_init();
   TensorFactory tf;
 
@@ -162,7 +166,7 @@ TEST(OpAddOutKernelTest, BoolAndBoolInputTensor) {
   EXPECT_TENSOR_EQ(out, tf.make(sizes, {false, true, true, true}));
 }
 
-TEST(OpAddOutKernelTest, BroadcastDimSizeIsOneAB) {
+TEST_F(OpAddOutKernelTest, BroadcastDimSizeIsOneAB) {
   TensorFactory tf;
 
   Tensor x = tf.make(
@@ -188,7 +192,7 @@ TEST(OpAddOutKernelTest, BroadcastDimSizeIsOneAB) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddOutKernelTest, BroadcastDimSizeMissingAB) {
+TEST_F(OpAddOutKernelTest, BroadcastDimSizeMissingAB) {
   TensorFactory tf;
 
   Tensor x = tf.make(
@@ -214,7 +218,7 @@ TEST(OpAddOutKernelTest, BroadcastDimSizeMissingAB) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddOutKernelTest, BroadcastDimSizeIsOneBA) {
+TEST_F(OpAddOutKernelTest, BroadcastDimSizeIsOneBA) {
   TensorFactory tf;
 
   Tensor x = tf.make({1, 2}, {0.7453382015228271, 0.3131374716758728});
@@ -240,7 +244,7 @@ TEST(OpAddOutKernelTest, BroadcastDimSizeIsOneBA) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddOutKernelTest, BroadcastDimSizeMissingBA) {
+TEST_F(OpAddOutKernelTest, BroadcastDimSizeMissingBA) {
   TensorFactory tf;
 
   Tensor x = tf.make({1, 2}, {0.7453382015228271, 0.3131374716758728});
@@ -266,7 +270,7 @@ TEST(OpAddOutKernelTest, BroadcastDimSizeMissingBA) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddOutKernelTest, BroadcastSupported) {
+TEST_F(OpAddOutKernelTest, BroadcastSupported) {
   TensorFactory tf;
 
   const std::vector sizes = {2, 2};
@@ -288,7 +292,7 @@ TEST(OpAddOutKernelTest, BroadcastSupported) {
 // Death Tests
 //
 
-TEST(OpAddOutKernelTest, IntInputsFloatAlphaDies) {
+TEST_F(OpAddOutKernelTest, IntInputsFloatAlphaDies) {
   // op_add_out() doesn't handle floating alpha for integral inputs
   TensorFactory tf;
 
@@ -300,10 +304,10 @@ TEST(OpAddOutKernelTest, IntInputsFloatAlphaDies) {
   // Elementwise add operation on two integral tensors with floating alpha
   // should cause an assertion and kill the test process.
   ET_EXPECT_KERNEL_FAILURE(
-      op_add_out(tf.ones(sizes), tf.ones(sizes), /*alpha=*/.7, out));
+      context_, op_add_out(tf.ones(sizes), tf.ones(sizes), /*alpha=*/.7, out));
 }
 
-TEST(OpAddOutKernelTest, BoolInputsFloatAlphaDies) {
+TEST_F(OpAddOutKernelTest, BoolInputsFloatAlphaDies) {
   // op_add_out() doesn't handle floating alpha for integral inputs
   TensorFactory tf;
 
@@ -315,10 +319,10 @@ TEST(OpAddOutKernelTest, BoolInputsFloatAlphaDies) {
   // Elementwise add operation on two integral tensors with floating alpha
   // should cause an assertion and kill the test process.
   ET_EXPECT_KERNEL_FAILURE(
-      op_add_out(tf.ones(sizes), tf.ones(sizes), /*alpha=*/.7, out));
+      context_, op_add_out(tf.ones(sizes), tf.ones(sizes), /*alpha=*/.7, out));
 }
 
-TEST(OpAddOutKernelTest, IntOutputWithFloatInputDies) {
+TEST_F(OpAddOutKernelTest, IntOutputWithFloatInputDies) {
   TensorFactory tfi;
   TensorFactory tff;
 
@@ -331,10 +335,10 @@ TEST(OpAddOutKernelTest, IntOutputWithFloatInputDies) {
   // Destination for the sum.
   Tensor out = tfi.zeros(sizes);
 
-  ET_EXPECT_KERNEL_FAILURE(op_add_out(a, b, /*alpha=*/1, out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, /*alpha=*/1, out));
 }
 
-TEST(OpAddOutKernelTest, BoolOutputWithIntegralInput) {
+TEST_F(OpAddOutKernelTest, BoolOutputWithIntegralInput) {
   // op_add_out() doesn't handle Bool.
   TensorFactory tf;
   TensorFactory tfi;
 
@@ -348,10 +352,10 @@ TEST(OpAddOutKernelTest, BoolOutputWithIntegralInput) {
   // Destination for the sum.
   Tensor out = tf.zeros(sizes);
 
-  ET_EXPECT_KERNEL_FAILURE(op_add_out(a, b, /*alpha=*/1, out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, /*alpha=*/1, out));
 }
 
-TEST(OpAddOutKernelTest, MismatchedInputShapesDies) {
+TEST_F(OpAddOutKernelTest, MismatchedInputShapesDies) {
   TensorFactory tf;
 
   // Addends with different shapes.
@@ -363,10 +367,10 @@ TEST(OpAddOutKernelTest, MismatchedInputShapesDies) {
 
   // Adding the two mismatched tensors should cause an assertion and kill the
   // test process.
-  ET_EXPECT_KERNEL_FAILURE(op_add_out(a, b, /*unused=*/0, out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, /*unused=*/0, out));
 }
 
-TEST(OpAddOutKernelTest, MismatchedOutputShapesDies) {
+TEST_F(OpAddOutKernelTest, MismatchedOutputShapesDies) {
   if (SupportedFeatures::get()->output_resize) {
     GTEST_SKIP()
         << "The current kernel supports implicitly resizing output tensor";
@@ -385,10 +389,10 @@ TEST(OpAddOutKernelTest, MismatchedOutputShapesDies) {
 
   // Adding the tensors into a mismatched output should cause an assertion and
   // kill the test process.
-  ET_EXPECT_KERNEL_FAILURE(op_add_out(a, b, /*unused=*/0, out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, /*unused=*/0, out));
 }
 
-TEST(OpAddOutKernelTest, SimpleGeneratedCase) {
+TEST_F(OpAddOutKernelTest, SimpleGeneratedCase) {
   et_pal_init();
   TensorFactory tf;
 
@@ -429,7 +433,7 @@ TEST(OpAddOutKernelTest, SimpleGeneratedCase) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddOutKernelTest, DynamicShapeUpperBoundSameAsExpected) {
+TEST_F(OpAddOutKernelTest, DynamicShapeUpperBoundSameAsExpected) {
   TensorFactory tf;
 
   Tensor x = tf.make(
@@ -463,7 +467,7 @@ TEST(OpAddOutKernelTest, DynamicShapeUpperBoundSameAsExpected) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddOutKernelTest, DynamicShapeUpperBoundLargerThanExpected) {
+TEST_F(OpAddOutKernelTest, DynamicShapeUpperBoundLargerThanExpected) {
   TensorFactory tf;
 
   Tensor x = tf.make(
@@ -497,7 +501,7 @@ TEST(OpAddOutKernelTest, DynamicShapeUpperBoundLargerThanExpected) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddOutKernelTest, DynamicShapeUnbound) {
+TEST_F(OpAddOutKernelTest, DynamicShapeUnbound) {
   GTEST_SKIP() << "Dynamic shape not supported";
   TensorFactory tf;
 
@@ -532,7 +536,7 @@ TEST(OpAddOutKernelTest, DynamicShapeUnbound) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddScalarOutKernelTest, SanityCheck) {
+TEST_F(OpAddScalarOutKernelTest, SanityCheck) {
   TensorFactory tf;
 
   const std::vector sizes = {2, 2};
@@ -545,7 +549,7 @@ TEST(OpAddScalarOutKernelTest, SanityCheck) {
   EXPECT_TENSOR_EQ(out, tf.make(sizes, {3, 4, 6, 10}));
 }
 
-TEST(OpAddScalarOutKernelTest, OptimizedSanityCheck) {
+TEST_F(OpAddScalarOutKernelTest, OptimizedSanityCheck) {
   TensorFactory tf;
 
   const std::vector sizes = {2, 2};
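Reviewer note: the alpha death tests above hinge on the `utils::extract_scalar` change made in op_add.cpp at the top of this diff. Extracting a floating-point alpha into an integral element type fails, and the kernel now records `InvalidArgument` on the context instead of aborting. A hedged sketch of that failure path (namespace abbreviated as in the kernel's using-declarations):

```cpp
// Sketch: why add.out on int tensors rejects alpha = 0.7.
int32_t alpha_val = 0;
// extract_scalar returns false when 0.7 cannot be represented losslessly
// as int32_t...
bool ok = utils::extract_scalar(exec_aten::Scalar(0.7), &alpha_val);
// ...so the kernel's ET_KERNEL_CHECK(ctx, ok, InvalidArgument, out) trips,
// sets the context failure state, and returns `out` untouched.
```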
diff --git a/kernels/test/op_addmm_test.cpp b/kernels/test/op_addmm_test.cpp
index a5425183c8f..b8f33289fc5 100644
--- a/kernels/test/op_addmm_test.cpp
+++ b/kernels/test/op_addmm_test.cpp
@@ -24,19 +24,51 @@ using exec_aten::ScalarType;
 using exec_aten::Tensor;
 using torch::executor::testing::TensorFactory;
 
-Tensor& op_addmm_out(
-    const Tensor& self,
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Scalar& beta,
-    const Scalar& alpha,
-    Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::addmm_outf(
-      context, self, mat1, mat2, beta, alpha, out);
-}
+class OpAddmmOutTest : public OperatorTest {
+ protected:
+  Tensor& op_addmm_out(
+      const Tensor& self,
+      const Tensor& mat1,
+      const Tensor& mat2,
+      const Scalar& beta,
+      const Scalar& alpha,
+      Tensor& out) {
+    return torch::executor::aten::addmm_outf(
+        context_, self, mat1, mat2, beta, alpha, out);
+  }
+
+  template
+  void test_dtype() {
+    TensorFactory tf;
+
+    if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+      if (DTYPE == ScalarType::Half) {
+        GTEST_SKIP()
+            << "skip Half because torch::executor::aten::mm_out does not support Half";
+        return;
+      }
+    }
+
+    // matmul gives 4 * 2 * 3 = 24, α * 24 = 48, 48 + β * self = 51
+    Tensor self = tf.full({3, 5}, 1);
+    Tensor x = tf.full({3, 4}, 2);
+    Tensor y = tf.full({4, 5}, 3);
+
+    // Output shape should be (3, 5)
+    Tensor out = tf.zeros({3, 5});
 
-TEST(OpAddmmOutTest, OutputDim) {
+    Scalar alpha = Scalar(2.0);
+    Scalar beta = Scalar(3.0);
+
+    op_addmm_out(self, x, y, beta, alpha, out);
+
+    Tensor expected = tf.full({3, 5}, 51);
+
+    EXPECT_TENSOR_EQ(out, expected);
+  }
+};
+
+TEST_F(OpAddmmOutTest, OutputDim) {
   TensorFactory tf;
 
   // 3 tensors with compatible dimensions: (3, 5), (3, 4) and (4, 5).
@@ -63,37 +95,7 @@ TEST(OpAddmmOutTest, OutputDim) {
 
 /// A generic smoke test that works for any dtype that supports ones() and
 /// zeros().
-template
-void test_dtype() {
-  TensorFactory tf;
-
-  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
-    if (DTYPE == ScalarType::Half) {
-      GTEST_SKIP()
-          << "skip Half because torch::executor::aten::mm_out does not support Half";
-      return;
-    }
-  }
-
-  // matmul gives 4 * 2 * 3 = 24, α * 24 = 48, 48 + β * self = 51
-  Tensor self = tf.full({3, 5}, 1);
-  Tensor x = tf.full({3, 4}, 2);
-  Tensor y = tf.full({4, 5}, 3);
-
-  // Output shape should be (3, 5)
-  Tensor out = tf.zeros({3, 5});
-
-  Scalar alpha = Scalar(2.0);
-  Scalar beta = Scalar(3.0);
-
-  op_addmm_out(self, x, y, beta, alpha, out);
-
-  Tensor expected = tf.full({3, 5}, 51);
-
-  EXPECT_TENSOR_EQ(out, expected);
-}
-
-TEST(OpAddmmOutTest, AllDtypesSupported) {
+TEST_F(OpAddmmOutTest, AllDtypesSupported) {
 #define TEST_ENTRY(ctype, dtype) test_dtype();
   ET_FORALL_REAL_TYPES_AND(Half, TEST_ENTRY);
 #undef TEST_ENTRY
@@ -102,7 +104,7 @@ TEST(OpAddmmOutTest, AllDtypesSupported) {
   // for those types.
 }
 
-TEST(OpAddmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
+TEST_F(OpAddmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
   TensorFactory tf;
 
   // Empty input matrices
@@ -119,7 +121,7 @@ TEST(OpAddmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
       op_addmm_out(self, x, y, Scalar(2), Scalar(3), out), expected);
 }
 
-TEST(OpAddmmOutTest, FloatTensorDtypeAndIntScalarTypePasses) {
+TEST_F(OpAddmmOutTest, FloatTensorDtypeAndIntScalarTypePasses) {
   // case 1: Tensor dtype float, scalar type int
   TensorFactory tff;
   // matmul gives 4 * 2 * 3 = 24, α * 24 = 72, 72 + β * self = 74
@@ -136,7 +138,7 @@ TEST(OpAddmmOutTest, FloatTensorDtypeAndIntScalarTypePasses) {
       op_addmm_out(self, x, y, Scalar(2), Scalar(3), out), expected);
 }
 
-TEST(OpAddmmOutTest, IntTensorDtypeAndFloatScalarTypePasses) {
+TEST_F(OpAddmmOutTest, IntTensorDtypeAndFloatScalarTypePasses) {
   // case 2: Tensor dtype int, scalar type float
   TensorFactory tfi;
   // matmul gives 4 * 2 * 3 = 24, α * 24 = 72, 72 + β * self = 74
@@ -153,7 +155,7 @@ TEST(OpAddmmOutTest, IntTensorDtypeAndFloatScalarTypePasses) {
       op_addmm_out(self, x, y, Scalar(2.0), Scalar(3.0), out), expected);
 }
 
-TEST(OpAddmmOutTest, InfinityTensorAndFloatScalarTypePasses) {
+TEST_F(OpAddmmOutTest, InfinityTensorAndFloatScalarTypePasses) {
   // case 3: Tensor dtype float with infinity values, scalar type int
   TensorFactory tff;
 
@@ -170,7 +172,7 @@ TEST(OpAddmmOutTest, InfinityTensorAndFloatScalarTypePasses) {
       op_addmm_out(self, x, y, Scalar(2), Scalar(3), out), expected);
 }
 
-TEST(OpAddmmOutTest, MismatchedDimensionsDies) {
+TEST_F(OpAddmmOutTest, MismatchedDimensionsDies) {
   TensorFactory tf;
 
   Tensor self = tf.full({2, 2}, 3);
@@ -184,13 +186,13 @@ TEST(OpAddmmOutTest, MismatchedDimensionsDies) {
   Tensor expected = tf.full({2, 2}, 9);
 
   ET_EXPECT_KERNEL_FAILURE(
-      op_addmm_out(self, x, wrong_y, Scalar(1), Scalar(1), out));
+      context_, op_addmm_out(self, x, wrong_y, Scalar(1), Scalar(1), out));
 
   EXPECT_TENSOR_EQ(
       op_addmm_out(self, x, right_y, Scalar(1), Scalar(1), out), expected);
 }
 
-TEST(OpAddmmOutTest, MismatchedDimensionSizeDies) {
+TEST_F(OpAddmmOutTest, MismatchedDimensionSizeDies) {
   TensorFactory tf;
   Tensor self = tf.full({2, 2}, 3);
   Tensor x = tf.full({2, 2}, 3);
@@ -208,12 +210,14 @@ TEST(OpAddmmOutTest, MismatchedDimensionSizeDies) {
   }
 
   ET_EXPECT_KERNEL_FAILURE(
+      context_,
       op_addmm_out(self, x, right_y, Scalar(1), Scalar(1), wrong_out));
 
   ET_EXPECT_KERNEL_FAILURE(
+      context_,
       op_addmm_out(self, x, wrong_y, Scalar(1), Scalar(1), right_out));
 }
 
-TEST(OpAddmmOutTest, WrongOutShapeDies) {
+TEST_F(OpAddmmOutTest, WrongOutShapeDies) {
   TensorFactory tf;
   Tensor self = tf.ones({10, 4});
   Tensor x = tf.ones({10, 3});
@@ -229,14 +233,14 @@ TEST(OpAddmmOutTest, WrongOutShapeDies) {
   }
 
   ET_EXPECT_KERNEL_FAILURE(
-      op_addmm_out(self, x, y, Scalar(1), Scalar(1), wrong_out));
+      context_, op_addmm_out(self, x, y, Scalar(1), Scalar(1), wrong_out));
 
   EXPECT_TENSOR_EQ(
       op_addmm_out(self, x, y, Scalar(1), Scalar(1), right_out),
       tf.full({10, 4}, 4));
 }
 
-TEST(OpAddmmOutTest, BroadcastTest) {
+TEST_F(OpAddmmOutTest, BroadcastTest) {
   TensorFactory tf;
 
   Tensor self = tf.make({1}, {1});
@@ -249,7 +253,7 @@ TEST(OpAddmmOutTest, BroadcastTest) {
       op_addmm_out(self, x, y, Scalar(1), Scalar(1), out),
       tf.make({2, 2}, {8, 11, 16, 23}));
 }
 
-TEST(OpAddmmOutKernelTest, BroadcastDimSize1) {
+TEST_F(OpAddmmOutTest, BroadcastDimSize1) {
   TensorFactory tf;
 
   Tensor x = tf.make({1, 2}, {0.9937992691993713, 0.7011417150497437});
@@ -301,7 +305,7 @@ TEST(OpAddmmOutKernelTest, BroadcastDimSize1) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddmmOutKernelTest, BroadcastDimSizeMissing) {
+TEST_F(OpAddmmOutTest, BroadcastDimSizeMissing) {
   TensorFactory tf;
 
   Tensor x = tf.make({2}, {0.9937992691993713, 0.7011417150497437});
@@ -353,7 +357,7 @@ TEST(OpAddmmOutKernelTest, BroadcastDimSizeMissing) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddmmOutKernelTest, BroadcastDimSizeIsOne) {
+TEST_F(OpAddmmOutTest, BroadcastDimSizeIsOne) {
   TensorFactory tf;
 
   Tensor x = tf.make({1, 2}, {0.9093303680419922, 0.37621551752090454});
@@ -405,7 +409,7 @@ TEST(OpAddmmOutKernelTest, BroadcastDimSizeIsOne) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddmmOutKernelTest, DynamicShapeUpperBoundSameAsExpected) {
+TEST_F(OpAddmmOutTest, DynamicShapeUpperBoundSameAsExpected) {
   TensorFactory tf;
 
   Tensor x = tf.make(
@@ -465,7 +469,7 @@ TEST(OpAddmmOutKernelTest, DynamicShapeUpperBoundSameAsExpected) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddmmOutKernelTest, DynamicShapeUpperBoundLargerThanExpected) {
+TEST_F(OpAddmmOutTest, DynamicShapeUpperBoundLargerThanExpected) {
   TensorFactory tf;
 
   Tensor x = tf.make(
@@ -525,7 +529,7 @@ TEST(OpAddmmOutKernelTest, DynamicShapeUpperBoundLargerThanExpected) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
-TEST(OpAddmmOutKernelTest, DynamicShapeUnbound) {
+TEST_F(OpAddmmOutTest, DynamicShapeUnbound) {
   GTEST_SKIP() << "Dynamic shape unbound not supported";
   TensorFactory tf;
 
diff --git a/kernels/test/op_alias_copy_test.cpp b/kernels/test/op_alias_copy_test.cpp
index 15e2018e7df..daa8c52dfb5 100644
--- a/kernels/test/op_alias_copy_test.cpp
+++ b/kernels/test/op_alias_copy_test.cpp
@@ -17,9 +17,11 @@
 
 using namespace ::testing;
 
-exec_aten::Tensor& op_alias_copy_out(
-    const exec_aten::Tensor& self,
-    exec_aten::Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::alias_copy_outf(context, self, out);
-}
+class OpAliasCopyTest : public OperatorTest {
+ protected:
+  exec_aten::Tensor& op_alias_copy_out(
+      const exec_aten::Tensor& self,
+      exec_aten::Tensor& out) {
+    return torch::executor::aten::alias_copy_outf(context_, self, out);
+  }
+};
diff --git a/kernels/test/op_amax_test.cpp b/kernels/test/op_amax_test.cpp
index a87423843b0..77f0b7da1d9 100644
--- a/kernels/test/op_amax_test.cpp
+++ b/kernels/test/op_amax_test.cpp
@@ -23,46 +23,212 @@ using exec_aten::ScalarType;
 using exec_aten::Tensor;
 using torch::executor::testing::TensorFactory;
 
-Tensor& op_amax_out(
-    const Tensor& in,
-    ArrayRef dim,
-    bool keepdim,
-    Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::amax_outf(context, in, dim, keepdim, out);
-}
-
-template
-void test_amax_out_invalid_dimensions() {
-  TensorFactory tf;
+class OpAmaxOutTest : public OperatorTest {
+ protected:
+  Tensor& op_amax_out(
+      const Tensor& in,
+      ArrayRef dim,
+      bool keepdim,
+      Tensor& out) {
+    return torch::executor::aten::amax_outf(context_, in, dim, keepdim, out);
+  }
 
-  // clang-format off
-  Tensor in = tf.make(
-    {2, 3, 4},
-    {
-      0, 1, 2, 4,
-      4, 2, 1, 0,
-      1, 0, 4, 2,
+  template
+  void test_amax_out_invalid_dimensions() {
+    TensorFactory tf;
+
+    // clang-format off
+    Tensor in = tf.make(
+      {2, 3, 4},
+      {
+        0, 1, 2, 4,
+        4, 2, 1, 0,
+        1, 0, 4, 2,
+
+        4, 2, 1, 0,
+        0, 1, 2, 4,
+        1, 0, 4, 2,
+      });
+    // clang-format on
+    Tensor out = tf.zeros({2, 3, 1});
+
+    // out-of-bound dim in dim list
+    int64_t dims_1[1] = {3};
+    ArrayRef dim_list{ArrayRef{dims_1, 1}};
+    ET_EXPECT_KERNEL_FAILURE(
+        context_, op_amax_out(in, dim_list, /*keepdim=*/true, out));
+
+    // the same dim appears multiple times in list of dims
+    int64_t dims_2[2] = {2, 2};
+    dim_list = ArrayRef{dims_2, 2};
+    ET_EXPECT_KERNEL_FAILURE(
+        context_, op_amax_out(in, dim_list, /*keepdim=*/true, out));
+  }
 
-      4, 2, 1, 0,
-      0, 1, 2, 4,
-      1, 0, 4, 2,
-    });
-  // clang-format on
-  Tensor out = tf.zeros({2, 3, 1});
+  template
+  void test_amax_out_invalid_shape() {
+    TensorFactory tf;
+
+    // clang-format off
+    Tensor in = tf.make(
+      {2, 3, 4},
+      {
+        0, 1, 2, 4,
+        4, 2, 1, 0,
+        1, 0, 4, 2,
+
+        4, 2, 1, 0,
+        0, 1, 2, 4,
+        1, 0, 4, 2,
+      });
+    // clang-format on
+
+    // dimension size mismatch when keepdim is true
+    Tensor out = tf.zeros({2, 4});
+
+    int64_t dims_1[1] = {1};
+    ArrayRef dim_list{ArrayRef{dims_1, 1}};
+    ET_EXPECT_KERNEL_FAILURE(
+        context_, op_amax_out(in, dim_list, /*keepdim=*/true, out));
+
+    // dimension size mismatch when keepdim is false
+    out = tf.zeros({2, 1, 4});
+    ET_EXPECT_KERNEL_FAILURE(
+        context_, op_amax_out(in, dim_list, /*keepdim=*/false, out));
+  }
 
-  // out-of-bound dim in dim list
-  int64_t dims_1[1] = {3};
-  ArrayRef dim_list{ArrayRef{dims_1, 1}};
-  ET_EXPECT_DEATH(op_amax_out(in, dim_list, /*keepdim=*/true, out), "");
+  template
+  void test_amax_out_dtype() {
+    TensorFactory tf;
+    // clang-format off
+    Tensor in = tf.make(
+      {2, 3, 4},
+      {
+        0, 1, 2, 4,
+        4, 2, 1, 0,
+        1, 5, 4, 2,
+
+        4, 2, 1, 0,
+        5, 1, 2, 4,
+        7, 5, 4, 2,
+      });
+    // clang-format on
+
+    // keepdim=true should work
+    Tensor out = tf.zeros({2, 3, 1});
+    int64_t dims_1[1] = {2};
+    ArrayRef dim_list{ArrayRef{dims_1, 1}};
+
+    op_amax_out(in, dim_list, /*keepdim=*/true, out);
+    // clang-format off
+    EXPECT_TENSOR_CLOSE(out, tf.make(
+      {2, 3, 1},
+      {4, 4, 5, 4, 5, 7}));
+    // clang-format on
+
+    // keepdim=false should work
+    out = tf.zeros({2, 3});
+    op_amax_out(in, dim_list, /*keepdim=*/false, out);
+    // clang-format off
+    EXPECT_TENSOR_CLOSE(out, tf.make(
+      {2, 3},
+      {4, 4, 5, 4, 5, 7}));
+    // clang-format on
+
+    // dim list with multiple dimensions should work
+    out = tf.zeros({1, 1, 4});
+    int64_t dims_2[2] = {0, 1};
+    dim_list = ArrayRef{dims_2, 2};
+    op_amax_out(in, dim_list, /*keepdim=*/true, out);
+    EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 4}, {7, 5, 4, 4}));
+
+    out = tf.zeros({4});
+    op_amax_out(in, dim_list, /*keepdim=*/false, out);
+    EXPECT_TENSOR_CLOSE(out, tf.make({4}, {7, 5, 4, 4}));
+
+    // dim list with negative dimensions should work
+    out = tf.zeros({2, 1, 4});
+    int64_t dims_3[1] = {-2};
+    dim_list = ArrayRef{dims_3, 1};
+    op_amax_out(in, dim_list, /*keepdim=*/true, out);
+    // clang-format off
+    EXPECT_TENSOR_CLOSE(out, tf.make(
+      {2, 1, 4},
+      {
+        4, 5, 4, 4,
+
+        7, 5, 4, 4,
+      }));
+    // clang-format on
+
+    // empty/null dim list should work
+    // clang-format off
+    in = tf.make(
+      {2, 2, 4},
+      {
+        8, 7, 5, 4,
+        4, 3, 7, 9,
+
+        4, 2, 6, 8,
+        8, 7, 3, 4,
+      });
+    // clang-format on
+    out = tf.zeros({1, 1, 1});
+    ArrayRef null_dim_list;
+    op_amax_out(in, null_dim_list, /*keepdim=*/true, out);
+    EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {9}));
+
+    ArrayRef empty_dim_list{ArrayRef{}};
+    op_amax_out(in, empty_dim_list, /*keepdim=*/true, out);
+    EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {9}));
+
+    out = tf.zeros({});
+    op_amax_out(in, null_dim_list, /*keepdim=*/false, out);
+    EXPECT_TENSOR_CLOSE(out, tf.make({}, {9}));
+
+    op_amax_out(in, empty_dim_list, /*keepdim=*/false, out);
+    EXPECT_TENSOR_CLOSE(out, tf.make({}, {9}));
+  }
 
-  // the same dim appears multiple times in list of dims
-  int64_t dims_2[2] = {2, 2};
-  dim_list = ArrayRef{dims_2, 2};
-  ET_EXPECT_DEATH(op_amax_out(in, dim_list, /*keepdim=*/true, out), "");
-}
+  template <>
+  void test_amax_out_dtype() {
+    TensorFactory tf_bool;
+    // clang-format off
+    Tensor in = tf_bool.make(
+      {2, 3, 4},
+      {
+        true,  false, true,  false,
+        false, false, false, false,
+        false, true,  true,  false,
+
+        false, false, true,  false,
+        false, false, false, true,
+        true,  true,  true,  true,
+      });
+    // clang-format on
+
+    Tensor out = tf_bool.zeros({2, 3, 1});
+
+    // +/-inf and nan should work
+    op_amax_out(in, /*dim=*/-1, /*keepdim=*/true, out);
+    // clang-format off
+    EXPECT_TENSOR_CLOSE(
+        out, tf_bool.make(
+      {2, 3, 1},
+      {
+        true,
+        false,
+        true,
+
+        true,
+        true,
+        true
+      }));
+    // clang-format on
+  }
+};
 
-TEST(OpAmaxOutTest, InvalidDimensionListDies) {
+TEST_F(OpAmaxOutTest, InvalidDimensionListDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel test fails";
   }
@@ -72,37 +238,7 @@ TEST(OpAmaxOutTest, InvalidDimensionListDies) {
 #undef TEST_ENTRY
 }
 
-template
-void test_amax_out_invalid_shape() {
-  TensorFactory tf;
-
-  // clang-format off
-  Tensor in = tf.make(
-    {2, 3, 4},
-    {
-      0, 1, 2, 4,
-      4, 2, 1, 0,
-      1, 0, 4, 2,
-
-      4, 2, 1, 0,
-      0, 1, 2, 4,
-      1, 0, 4, 2,
-    });
-  // clang-format on
-
-  // dimension size mismatch when keepdim is true
-  Tensor out = tf.zeros({2, 4});
-
-  int64_t dims_1[1] = {1};
-  ArrayRef dim_list{ArrayRef{dims_1, 1}};
-  ET_EXPECT_DEATH(op_amax_out(in, dim_list, /*keepdim=*/true, out), "");
-
-  // dimension size mismatch when keepdim is false
-  out = tf.zeros({2, 1, 4});
-  ET_EXPECT_DEATH(op_amax_out(in, dim_list, /*keepdim=*/false, out), "");
-}
-
-TEST(OpAmaxOutTest, InvalidShapeDies) {
+TEST_F(OpAmaxOutTest, InvalidShapeDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel test fails";
   }
@@ -112,7 +248,7 @@ TEST(OpAmaxOutTest, InvalidShapeDies) {
 #undef TEST_ENTRY
 }
 
-TEST(OpAmaxOutTest, MismatchedDTypesDies) {
+TEST_F(OpAmaxOutTest, MismatchedDTypesDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel test fails";
   }
@@ -138,146 +274,17 @@ TEST(OpAmaxOutTest, MismatchedDTypesDies) {
   ArrayRef dim_list{ArrayRef{dims_1, 1}};
 
   // out tensor should be of the same dtype with dtype when dtype is specified
-  ET_EXPECT_DEATH(op_amax_out(in, dim_list, /*keepdim=*/true, out), "");
-}
-
-template
-void test_amax_out_dtype() {
-  TensorFactory tf;
-  // clang-format off
-  Tensor in = tf.make(
-    {2, 3, 4},
-    {
-      0, 1, 2, 4,
-      4, 2, 1, 0,
-      1, 5, 4, 2,
-
-      4, 2, 1, 0,
-      5, 1, 2, 4,
-      7, 5, 4, 2,
-    });
-  // clang-format on
-
-  // keepdim=true should work
-  Tensor out = tf.zeros({2, 3, 1});
-  int64_t dims_1[1] = {2};
-  ArrayRef dim_list{ArrayRef{dims_1, 1}};
-
-  op_amax_out(in, dim_list, /*keepdim=*/true, out);
-  // clang-format off
-  EXPECT_TENSOR_CLOSE(out, tf.make(
-    {2, 3, 1},
-    {4, 4, 5, 4, 5, 7}));
-  // clang-format on
-
-  // keepdim=false should work
-  out = tf.zeros({2, 3});
-  op_amax_out(in, dim_list, /*keepdim=*/false, out);
-  // clang-format off
-  EXPECT_TENSOR_CLOSE(out, tf.make(
-    {2, 3},
-    {4, 4, 5, 4, 5, 7}));
-  // clang-format on
-
-  // dim list with multiple dimensions should work
-  out = tf.zeros({1, 1, 4});
-  int64_t dims_2[2] = {0, 1};
-  dim_list = ArrayRef{dims_2, 2};
-  op_amax_out(in, dim_list, /*keepdim=*/true, out);
-  EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 4}, {7, 5, 4, 4}));
-
-  out = tf.zeros({4});
-  op_amax_out(in, dim_list, /*keepdim=*/false, out);
-  EXPECT_TENSOR_CLOSE(out, tf.make({4}, {7, 5, 4, 4}));
-
-  // dim list with negative dimensions should work
-  out = tf.zeros({2, 1, 4});
-  int64_t dims_3[1] = {-2};
-  dim_list = ArrayRef{dims_3, 1};
-  op_amax_out(in, dim_list, /*keepdim=*/true, out);
-  // clang-format off
-  EXPECT_TENSOR_CLOSE(out, tf.make(
-    {2, 1, 4},
-    {
-      4, 5, 4, 4,
-
-      7, 5, 4, 4,
-    }));
-  // clang-format on
-
-  // empty/null dim list should work
-  // clang-format off
-  in = tf.make(
-    {2, 2, 4},
-    {
-      8, 7, 5, 4,
-      4, 3, 7, 9,
-
-      4, 2, 6, 8,
-      8, 7, 3, 4,
-    });
-  // clang-format on
-  out = tf.zeros({1, 1, 1});
-  ArrayRef null_dim_list;
-  op_amax_out(in, null_dim_list, /*keepdim=*/true, out);
-  EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {9}));
-
-  ArrayRef empty_dim_list{ArrayRef{}};
-  op_amax_out(in, empty_dim_list, /*keepdim=*/true, out);
-  EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {9}));
-
-  out = tf.zeros({});
-  op_amax_out(in, null_dim_list, /*keepdim=*/false, out);
-  EXPECT_TENSOR_CLOSE(out, tf.make({}, {9}));
-
-  op_amax_out(in, empty_dim_list, /*keepdim=*/false, out);
-  EXPECT_TENSOR_CLOSE(out, tf.make({}, {9}));
-}
-
-template <>
-void test_amax_out_dtype() {
-  TensorFactory tf_bool;
-  // clang-format off
-  Tensor in = tf_bool.make(
-    {2, 3, 4},
-    {
-      true,  false, true,  false,
-      false, false, false, false,
-      false, true,  true,  false,
-
-      false, false, true,  false,
-      false, false, false, true,
-      true,  true,  true,  true,
-    });
-  // clang-format on
-
-  Tensor out = tf_bool.zeros({2, 3, 1});
-
-  // +/-inf and nan should work
-  op_amax_out(in, /*dim=*/-1, /*keepdim=*/true, out);
-  // clang-format off
-  EXPECT_TENSOR_CLOSE(
-      out, tf_bool.make(
-    {2, 3, 1},
-    {
-      true,
-      false,
-      true,
-
-      true,
-      true,
-      true
-    }));
-  // clang-format on
+  ET_EXPECT_KERNEL_FAILURE(
+      context_, op_amax_out(in, dim_list, /*keepdim=*/true, out));
 }
 
-TEST(OpAmaxOutTest, AllRealInputOutputPasses) {
+TEST_F(OpAmaxOutTest, AllRealInputOutputPasses) {
 #define TEST_ENTRY(ctype, dtype) test_amax_out_dtype();
   ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
-TEST(OpAmaxOutTest, InfinityAndNANTest) {
+TEST_F(OpAmaxOutTest, InfinityAndNANTest) {
   TensorFactory tf_float;
   // clang-format off
   Tensor in = tf_float.make(
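Reviewer note: these reduction tests build their dim lists as `ArrayRef` views over stack arrays. `exec_aten::ArrayRef` is non-owning, so the backing array must outlive the view. A short hedged usage sketch; the `int64_t` element type is an assumption, since extraction dropped the original template arguments:

```cpp
// Sketch: dim-list construction for the reduction ops under test.
int64_t dims[2] = {0, 1};  // must outlive dim_list
exec_aten::ArrayRef<int64_t> dim_list(dims, 2);

// A default-constructed (empty) list means "reduce over all dimensions"
// in these kernels, as the empty/null dim list cases above exercise.
exec_aten::ArrayRef<int64_t> all_dims;
```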
dim, - bool keepdim, - Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::amin_outf(context, in, dim, keepdim, out); -} - -template -void test_amin_out_invalid_dimensions() { - TensorFactory tf; +class OpAminOutTest : public OperatorTest { + protected: + Tensor& op_amin_out( + const Tensor& in, + ArrayRef dim, + bool keepdim, + Tensor& out) { + return torch::executor::aten::amin_outf(context_, in, dim, keepdim, out); + } - // clang-format off - Tensor in = tf.make( - {2, 3, 4}, - { - 0, 1, 2, 4, - 4, 2, 1, 0, - 1, 0, 4, 2, + template + void test_amin_out_invalid_dimensions() { + TensorFactory tf; + + // clang-format off + Tensor in = tf.make( + {2, 3, 4}, + { + 0, 1, 2, 4, + 4, 2, 1, 0, + 1, 0, 4, 2, + + 4, 2, 1, 0, + 0, 1, 2, 4, + 1, 0, 4, 2, + }); + // clang-format on + Tensor out = tf.zeros({2, 3, 1}); + + // out-of-bound dim in dim list + int64_t dims_1[1] = {3}; + ArrayRef dim_list{ArrayRef{dims_1, 1}}; + ET_EXPECT_KERNEL_FAILURE( + context_, op_amin_out(in, dim_list, /*keepdim=*/true, out)); + + // the same dim appears multiple times in list of dims + int64_t dims_2[2] = {2, 2}; + dim_list = ArrayRef{dims_2, 2}; + ET_EXPECT_KERNEL_FAILURE( + context_, op_amin_out(in, dim_list, /*keepdim=*/true, out)); + } - 4, 2, 1, 0, - 0, 1, 2, 4, - 1, 0, 4, 2, - }); - // clang-format on - Tensor out = tf.zeros({2, 3, 1}); + template + void test_amin_out_invalid_shape() { + TensorFactory tf; + + // clang-format off + Tensor in = tf.make( + {2, 3, 4}, + { + 0, 1, 2, 4, + 4, 2, 1, 0, + 1, 0, 4, 2, + + 4, 2, 1, 0, + 0, 1, 2, 4, + 1, 0, 4, 2, + }); + // clang-format on + + // dimension size mismatch when keepdim is true + Tensor out = tf.zeros({2, 4}); + + int64_t dims_1[1] = {1}; + ArrayRef dim_list{ArrayRef{dims_1, 1}}; + ET_EXPECT_KERNEL_FAILURE( + context_, op_amin_out(in, dim_list, /*keepdim=*/true, out)); + + // dimension size mismatch when keepdim is false + out = tf.zeros({2, 1, 4}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_amin_out(in, dim_list, /*keepdim=*/false, out)); + } - // out-of-bound dim in dim list - int64_t dims_1[1] = {3}; - ArrayRef dim_list{ArrayRef{dims_1, 1}}; - ET_EXPECT_DEATH(op_amin_out(in, dim_list, /*keepdim=*/true, out), ""); + template + void test_amin_out_dtype() { + TensorFactory tf; + // clang-format off + Tensor in = tf.make( + {2, 3, 4}, + { + 0, 1, 2, 4, + 4, 2, 1, 0, + 1, 5, 4, 2, + + 4, 2, 1, 0, + 5, 1, 2, 4, + 7, 5, 4, 2, + }); + // clang-format on + + // keepdim=true should work + Tensor out = tf.zeros({2, 3, 1}); + int64_t dims_1[1] = {2}; + ArrayRef dim_list{ArrayRef{dims_1, 1}}; + + op_amin_out(in, dim_list, /*keepdim=*/true, out); + // clang-format off + EXPECT_TENSOR_CLOSE(out, tf.make( + {2, 3, 1}, + {0, 0, 1, 0, 1, 2})); + // clang-format on + + // keepdim=false should work + out = tf.zeros({2, 3}); + op_amin_out(in, dim_list, /*keepdim=*/false, out); + // clang-format off + EXPECT_TENSOR_CLOSE(out, tf.make( + {2, 3}, + {0, 0, 1, 0, 1, 2})); + // clang-format on + + // dim list with multiple dimensions should work + out = tf.zeros({1, 1, 4}); + int64_t dims_2[2] = {0, 1}; + dim_list = ArrayRef{dims_2, 2}; + op_amin_out(in, dim_list, /*keepdim=*/true, out); + EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 4}, {0, 1, 1, 0})); + + out = tf.zeros({4}); + op_amin_out(in, dim_list, /*keepdim=*/false, out); + EXPECT_TENSOR_CLOSE(out, tf.make({4}, {0, 1, 1, 0})); + + // dim list with negative dimensions should work + out = tf.zeros({2, 1, 4}); + int64_t dims_3[1] = {-2}; + dim_list = ArrayRef{dims_3, 1}; + op_amin_out(in, dim_list, 
/*keepdim=*/true, out); + // clang-format off + EXPECT_TENSOR_CLOSE(out, tf.make( + {2, 1, 4}, + { + 0, 1, 1, 0, + + 4, 1, 1, 0, + })); + // clang-format on + + // empty/null dim list should work + // clang-format off + in = tf.make( + {2, 2, 4}, + { + 8, 7, 5, 4, + 4, 3, 7, 9, + + 4, 2, 6, 8, + 8, 7, 3, 4, + }); + // clang-format on + out = tf.zeros({1, 1, 1}); + ArrayRef null_dim_list; + op_amin_out(in, null_dim_list, /*keepdim=*/true, out); + EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {2})); + + ArrayRef empty_dim_list{ArrayRef{}}; + op_amin_out(in, empty_dim_list, /*keepdim=*/true, out); + EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {2})); + + out = tf.zeros({}); + op_amin_out(in, null_dim_list, /*keepdim=*/false, out); + EXPECT_TENSOR_CLOSE(out, tf.make({}, {2})); + + op_amin_out(in, empty_dim_list, /*keepdim=*/false, out); + EXPECT_TENSOR_CLOSE(out, tf.make({}, {2})); + } - // the same dim appears multiple times in list of dims - int64_t dims_2[2] = {2, 2}; - dim_list = ArrayRef{dims_2, 2}; - ET_EXPECT_DEATH(op_amin_out(in, dim_list, /*keepdim=*/true, out), ""); -} + template <> + void test_amin_out_dtype() { + TensorFactory tf_bool; + // clang-format off + Tensor in = tf_bool.make( + {2, 3, 4}, + { + true, false, true, false, + false, false, false, false, + false, true, true, false, + + false, false, true, false, + false, false, false, true, + true, true, true, true, + }); + // clang-format on + + Tensor out = tf_bool.zeros({2, 3, 1}); + + // +/-inf and nan should work + op_amin_out(in, /*dim=*/-1, /*keepdim=*/true, out); + // clang-format off + EXPECT_TENSOR_CLOSE( + out, tf_bool.make( + {2, 3, 1}, + { + false, + false, + false, + + false, + false, + true + })); + // clang-format on + } +}; -TEST(OpAminOutTest, InvalidDimensionListDies) { +TEST_F(OpAminOutTest, InvalidDimensionListDies) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel test fails"; } @@ -72,37 +238,7 @@ TEST(OpAminOutTest, InvalidDimensionListDies) { #undef TEST_ENTRY } -template -void test_amin_out_invalid_shape() { - TensorFactory tf; - - // clang-format off - Tensor in = tf.make( - {2, 3, 4}, - { - 0, 1, 2, 4, - 4, 2, 1, 0, - 1, 0, 4, 2, - - 4, 2, 1, 0, - 0, 1, 2, 4, - 1, 0, 4, 2, - }); - // clang-format on - - // dimension size mismatch when keepdim is true - Tensor out = tf.zeros({2, 4}); - - int64_t dims_1[1] = {1}; - ArrayRef dim_list{ArrayRef{dims_1, 1}}; - ET_EXPECT_DEATH(op_amin_out(in, dim_list, /*keepdim=*/true, out), ""); - - // dimension size mismatch when keepdim is false - out = tf.zeros({2, 1, 4}); - ET_EXPECT_DEATH(op_amin_out(in, dim_list, /*keepdim=*/false, out), ""); -} - -TEST(OpAminOutTest, InvalidShapeDies) { +TEST_F(OpAminOutTest, InvalidShapeDies) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel test fails"; } @@ -112,7 +248,7 @@ TEST(OpAminOutTest, InvalidShapeDies) { #undef TEST_ENTRY } -TEST(OpAminOutTest, MismatchedDTypesDies) { +TEST_F(OpAminOutTest, MismatchedDTypesDies) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel test fails"; } @@ -138,146 +274,17 @@ TEST(OpAminOutTest, MismatchedDTypesDies) { ArrayRef dim_list{ArrayRef{dims_1, 1}}; // out tensor should be of the same dtype with dtype when dtype is specified - ET_EXPECT_DEATH(op_amin_out(in, dim_list, /*keepdim=*/true, out), ""); -} - -template -void test_amin_out_dtype() { - TensorFactory tf; - // clang-format off - Tensor in = tf.make( - {2, 3, 4}, - { - 0, 1, 2, 4, - 4, 2, 1, 0, - 
1, 5, 4, 2, - - 4, 2, 1, 0, - 5, 1, 2, 4, - 7, 5, 4, 2, - }); - // clang-format on - - // keepdim=true should work - Tensor out = tf.zeros({2, 3, 1}); - int64_t dims_1[1] = {2}; - ArrayRef dim_list{ArrayRef{dims_1, 1}}; - - op_amin_out(in, dim_list, /*keepdim=*/true, out); - // clang-format off - EXPECT_TENSOR_CLOSE(out, tf.make( - {2, 3, 1}, - {0, 0, 1, 0, 1, 2})); - // clang-format on - - // keepdim=false should work - out = tf.zeros({2, 3}); - op_amin_out(in, dim_list, /*keepdim=*/false, out); - // clang-format off - EXPECT_TENSOR_CLOSE(out, tf.make( - {2, 3}, - {0, 0, 1, 0, 1, 2})); - // clang-format on - - // dim list with multiple dimensions should work - out = tf.zeros({1, 1, 4}); - int64_t dims_2[2] = {0, 1}; - dim_list = ArrayRef{dims_2, 2}; - op_amin_out(in, dim_list, /*keepdim=*/true, out); - EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 4}, {0, 1, 1, 0})); - - out = tf.zeros({4}); - op_amin_out(in, dim_list, /*keepdim=*/false, out); - EXPECT_TENSOR_CLOSE(out, tf.make({4}, {0, 1, 1, 0})); - - // dim list with negative dimensions should work - out = tf.zeros({2, 1, 4}); - int64_t dims_3[1] = {-2}; - dim_list = ArrayRef{dims_3, 1}; - op_amin_out(in, dim_list, /*keepdim=*/true, out); - // clang-format off - EXPECT_TENSOR_CLOSE(out, tf.make( - {2, 1, 4}, - { - 0, 1, 1, 0, - - 4, 1, 1, 0, - })); - // clang-format on - - // empty/null dim list should work - // clang-format off - in = tf.make( - {2, 2, 4}, - { - 8, 7, 5, 4, - 4, 3, 7, 9, - - 4, 2, 6, 8, - 8, 7, 3, 4, - }); - // clang-format on - out = tf.zeros({1, 1, 1}); - ArrayRef null_dim_list; - op_amin_out(in, null_dim_list, /*keepdim=*/true, out); - EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {2})); - - ArrayRef empty_dim_list{ArrayRef{}}; - op_amin_out(in, empty_dim_list, /*keepdim=*/true, out); - EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {2})); - - out = tf.zeros({}); - op_amin_out(in, null_dim_list, /*keepdim=*/false, out); - EXPECT_TENSOR_CLOSE(out, tf.make({}, {2})); - - op_amin_out(in, empty_dim_list, /*keepdim=*/false, out); - EXPECT_TENSOR_CLOSE(out, tf.make({}, {2})); -} - -template <> -void test_amin_out_dtype() { - TensorFactory tf_bool; - // clang-format off - Tensor in = tf_bool.make( - {2, 3, 4}, - { - true, false, true, false, - false, false, false, false, - false, true, true, false, - - false, false, true, false, - false, false, false, true, - true, true, true, true, - }); - // clang-format on - - Tensor out = tf_bool.zeros({2, 3, 1}); - - // +/-inf and nan should work - op_amin_out(in, /*dim=*/-1, /*keepdim=*/true, out); - // clang-format off - EXPECT_TENSOR_CLOSE( - out, tf_bool.make( - {2, 3, 1}, - { - false, - false, - false, - - false, - false, - true - })); - // clang-format on + ET_EXPECT_KERNEL_FAILURE( + context_, op_amin_out(in, dim_list, /*keepdim=*/true, out)); } -TEST(OpAminOutTest, AllRealInputOutputPasses) { +TEST_F(OpAminOutTest, AllRealInputOutputPasses) { #define TEST_ENTRY(ctype, dtype) test_amin_out_dtype(); ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpAminOutTest, InfinityAndNANTest) { +TEST_F(OpAminOutTest, InfinityAndNANTest) { TensorFactory tf_float; // clang-format off Tensor in = tf_float.make( diff --git a/kernels/test/op_any_test.cpp b/kernels/test/op_any_test.cpp index 9c4d5a9675b..09f9cdd4991 100644 --- a/kernels/test/op_any_test.cpp +++ b/kernels/test/op_any_test.cpp @@ -23,54 +23,82 @@ using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::testing::TensorFactory; -Tensor& op_any_all_out(const Tensor& input, Tensor& out) { - 
exec_aten::RuntimeContext context{}; - return torch::executor::aten::any_outf(context, input, out); -} - -Tensor& op_any_dims_out( - const Tensor& input, - optional> dim, - bool keepdim, - Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::any_outf(context, input, dim, keepdim, out); -} +class OpAnyOutTest : public OperatorTest { + protected: + Tensor& op_any_all_out(const Tensor& input, Tensor& out) { + return torch::executor::aten::any_outf(context_, input, out); + } -Tensor& -op_any_out(const Tensor& input, int64_t dim, bool keepdim, Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::any_outf(context, input, dim, keepdim, out); -} + Tensor& op_any_dims_out( + const Tensor& input, + optional> dim, + bool keepdim, + Tensor& out) { + return torch::executor::aten::any_outf(context_, input, dim, keepdim, out); + } -class OpAnyAllOutTest : public ::testing::Test { - protected: - void SetUp() override { - // Since these tests cause ET_LOG to be called, the PAL must be initialized - // first. - torch::executor::runtime_init(); + Tensor& + op_any_out(const Tensor& input, int64_t dim, bool keepdim, Tensor& out) { + return torch::executor::aten::any_outf(context_, input, dim, keepdim, out); } -}; -class OpAnyDimsOutTest : public ::testing::Test { - protected: - void SetUp() override { - // Since these tests cause ET_LOG to be called, the PAL must be initialized - // first. - torch::executor::runtime_init(); + template + void test_any_all_out_invalid_type() { + TensorFactory tf_float; + TensorFactory tf_out; + + Tensor in = tf_float.make( + {1, 4}, + { + 0, + 0, + 1, + 0, + }); + Tensor out = tf_out.zeros(/*size=*/{0}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_any_all_out(in, out)); } -}; -class OpAnyOutTest : public ::testing::Test { - protected: - void SetUp() override { - // Since these tests cause ET_LOG to be called, the PAL must be initialized - // first. 
- torch::executor::runtime_init(); + template + void test_any_all_out() { + TensorFactory tf_in; + TensorFactory tf_bool; + // clang-format off + Tensor in = tf_in.make( + {2, 4}, + { + 0, 1, 0, 1, + 1, 0, 1, 0 + }); + Tensor bool_false_in = tf_bool.make( + {2, 4}, + { + false, false, false, false, + false, false, false, false, + }); + Tensor bool_true_in = tf_bool.make( + {2, 4}, + { + true, true, true, true, + true, true, true, true, + }); + // clang-format on + + Tensor out = tf_bool.make({}, {false}); + + op_any_all_out(in, out); + EXPECT_TENSOR_EQ(out, tf_bool.make({}, {true})); + + op_any_all_out(bool_false_in, out); + EXPECT_TENSOR_EQ(out, tf_bool.make({}, {false})); + + op_any_all_out(bool_true_in, out); + EXPECT_TENSOR_EQ(out, tf_bool.make({}, {true})); } }; -TEST_F(OpAnyAllOutTest, MismatchedDimensionsDies) { +TEST_F(OpAnyOutTest, MismatchedDimensionsDies) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel can handle mismatched dimensions"; } @@ -80,78 +108,23 @@ TEST_F(OpAnyAllOutTest, MismatchedDimensionsDies) { Tensor in = tff.make(size, {0, 0, 1, 0}); Tensor out = tff.ones(/*size=*/{1, 1}); - ET_EXPECT_KERNEL_FAILURE(op_any_all_out(in, out)); + ET_EXPECT_KERNEL_FAILURE(context_, op_any_all_out(in, out)); } -template -void test_any_all_out_invalid_type() { - TensorFactory tf_float; - TensorFactory tf_out; - - Tensor in = tf_float.make( - {1, 4}, - { - 0, - 0, - 1, - 0, - }); - Tensor out = tf_out.zeros(/*size=*/{0}); - - ET_EXPECT_KERNEL_FAILURE(op_any_all_out(in, out)); -} - -TEST_F(OpAnyAllOutTest, InvalidDtypeDies) { +TEST_F(OpAnyOutTest, InvalidDtypeDies) { #define TEST_ENTRY(ctype, dtype) \ test_any_all_out_invalid_type(); ET_FORALL_FLOAT_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -template -void test_any_all_out() { - TensorFactory tf_in; - TensorFactory tf_bool; - // clang-format off - Tensor in = tf_in.make( - {2, 4}, - { - 0, 1, 0, 1, - 1, 0, 1, 0 - }); - Tensor bool_false_in = tf_bool.make( - {2, 4}, - { - false, false, false, false, - false, false, false, false, - }); - Tensor bool_true_in = tf_bool.make( - {2, 4}, - { - true, true, true, true, - true, true, true, true, - }); - // clang-format on - - Tensor out = tf_bool.make({}, {false}); - - op_any_all_out(in, out); - EXPECT_TENSOR_EQ(out, tf_bool.make({}, {true})); - - op_any_all_out(bool_false_in, out); - EXPECT_TENSOR_EQ(out, tf_bool.make({}, {false})); - - op_any_all_out(bool_true_in, out); - EXPECT_TENSOR_EQ(out, tf_bool.make({}, {true})); -} - -TEST_F(OpAnyAllOutTest, AllRealInputTypePasses) { +TEST_F(OpAnyOutTest, AllRealInputTypePasses) { #define TEST_ENTRY(ctype, dtype) test_any_all_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST_F(OpAnyDimsOutTest, SmokeTest) { +TEST_F(OpAnyOutTest, SmokeTestDims) { TensorFactory tfBool; Tensor self = tfBool.make({2, 3, 1}, {true, false, true, true, false, false}); diff --git a/kernels/test/op_arange_test.cpp b/kernels/test/op_arange_test.cpp index e355899f7aa..7bacf93c740 100644 --- a/kernels/test/op_arange_test.cpp +++ b/kernels/test/op_arange_test.cpp @@ -27,65 +27,72 @@ using exec_aten::Tensor; using torch::executor::testing::TensorFactory; -Tensor& op_arange_out(const Scalar& end, Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::arange_outf(context, end, out); -} +class OpArangeOutTest : public OperatorTest { + protected: + Tensor& op_arange_out(const Scalar& end, Tensor& out) { + return torch::executor::aten::arange_outf(context_, end, out); + } -Tensor& 
op_arange_start_out( - const Scalar& start, - const Scalar& end, - const Scalar& step, - Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::arange_outf(context, start, end, step, out); -} + template + void test_arange_dtype() { + TensorFactory tf; -class OpArangeOutTest : public ::testing::Test { - protected: - void SetUp() override { - // Since these tests cause ET_LOG to be called, the PAL must be initialized - // first. - torch::executor::runtime_init(); + Scalar end = Scalar(static_cast(10)); + + Tensor out = tf.zeros({10}); + + Tensor ret = op_arange_out(end, out); + + // Should always return the provided out Tensor. + EXPECT_TENSOR_EQ(ret, out); + + // Expected tensor, filled with 0, 1, ..., 9 + Tensor expected = tf.make({10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + EXPECT_TENSOR_EQ(out, expected); } }; -class OpArangeStartOutTest : public ::testing::Test { +class OpArangeStartOutTest : public OperatorTest { protected: - void SetUp() override { - // Since these tests cause ET_LOG to be called, the PAL must be initialized - // first. - torch::executor::runtime_init(); + Tensor& op_arange_start_out( + const Scalar& start, + const Scalar& end, + const Scalar& step, + Tensor& out) { + return torch::executor::aten::arange_outf(context_, start, end, step, out); } -}; -/// A generic smoke test that works for any dtype that supports zeros(). -template -void test_arange_dtype() { - TensorFactory tf; + template + void test_arange_start_dtype() { + TensorFactory tf; - Scalar end = Scalar(static_cast(10)); + Scalar start = Scalar(static_cast(0)); + Scalar end = Scalar(static_cast(10)); + Scalar step = Scalar(static_cast(1)); - Tensor out = tf.zeros({10}); + Tensor out = tf.zeros({10}); - Tensor ret = op_arange_out(end, out); + Tensor ret = op_arange_start_out(start, end, step, out); - // Should always return the provided out Tensor. - EXPECT_TENSOR_EQ(ret, out); + // Should always return the provided out Tensor. + EXPECT_TENSOR_EQ(ret, out); - // Expected tensor, filled with 0, 1, ..., 9 - Tensor expected = tf.make({10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + // Expected tensor, filled with 0, 1, ..., 9 + Tensor expected = tf.make({10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - EXPECT_TENSOR_EQ(out, expected); -} + EXPECT_TENSOR_EQ(out, expected); + } +}; -TEST(OpArangeOutTest, AllRealDtypesSupported) { +/// A generic smoke test that works for any dtype that supports zeros(). 
+TEST_F(OpArangeOutTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_arange_dtype(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpArangeOutTest, FloatNumberNotEqualIntSupport) { +TEST_F(OpArangeOutTest, FloatNumberNotEqualIntSupport) { TensorFactory tf; // end = any floating point number between [a, a+1) where a is an arbitrary @@ -106,7 +113,7 @@ TEST(OpArangeOutTest, FloatNumberNotEqualIntSupport) { EXPECT_TENSOR_EQ(out, expected); } -TEST(OpArangeOutTest, OutDimUnsupportedDie) { +TEST_F(OpArangeOutTest, OutDimUnsupportedDie) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel can handle mismatched out dim"; } @@ -117,10 +124,10 @@ TEST(OpArangeOutTest, OutDimUnsupportedDie) { Tensor out = tf.zeros({5, 1}); // out.dim() should be 1, not 2 - ET_EXPECT_KERNEL_FAILURE(op_arange_out(end, out)); + ET_EXPECT_KERNEL_FAILURE(context_, op_arange_out(end, out)); } -TEST(OpArangeOutTest, DynamicShapeUpperBoundSameAsExpected) { +TEST_F(OpArangeOutTest, DynamicShapeUpperBoundSameAsExpected) { TensorFactory tf; Tensor expected_result = tf.make({5}, {0, 1, 2, 3, 4}); @@ -131,7 +138,7 @@ TEST(OpArangeOutTest, DynamicShapeUpperBoundSameAsExpected) { EXPECT_TENSOR_CLOSE(out, expected_result); } -TEST(OpArangeOutTest, DynamicShapeUpperBoundLargerThanExpected) { +TEST_F(OpArangeOutTest, DynamicShapeUpperBoundLargerThanExpected) { TensorFactory tf; Tensor expected_result = tf.make({5}, {0, 1, 2, 3, 4}); @@ -142,7 +149,7 @@ TEST(OpArangeOutTest, DynamicShapeUpperBoundLargerThanExpected) { EXPECT_TENSOR_CLOSE(out, expected_result); } -TEST(OpArangeOutTest, DynamicShapeUnbound) { +TEST_F(OpArangeOutTest, DynamicShapeUnbound) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic Unbound not supported"; } @@ -157,35 +164,14 @@ TEST(OpArangeOutTest, DynamicShapeUnbound) { } /// A generic smoke test that works for any dtype that supports zeros(). -template -void test_arange_start_dtype() { - TensorFactory tf; - - Scalar start = Scalar(static_cast(0)); - Scalar end = Scalar(static_cast(10)); - Scalar step = Scalar(static_cast(1)); - - Tensor out = tf.zeros({10}); - - Tensor ret = op_arange_start_out(start, end, step, out); - - // Should always return the provided out Tensor. 
- EXPECT_TENSOR_EQ(ret, out); - - // Expected tensor, filled with 0, 1, ..., 9 - Tensor expected = tf.make({10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - - EXPECT_TENSOR_EQ(out, expected); -} - -TEST(OpArangeStartOutTest, AllRealDtypesSupported) { +TEST_F(OpArangeStartOutTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) \ test_arange_start_dtype(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpArangeStartOutTest, FloatNumberNotEqualIntSupport) { +TEST_F(OpArangeStartOutTest, FloatNumberNotEqualIntSupport) { TensorFactory tf; // Tested in bento: @@ -209,7 +195,7 @@ TEST(OpArangeStartOutTest, FloatNumberNotEqualIntSupport) { EXPECT_TENSOR_EQ(out, expected); } -TEST(OpArangeStartOutTest, OutDimUnsupportedDie) { +TEST_F(OpArangeStartOutTest, OutDimUnsupportedDie) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel can handle mismatched out dim"; } @@ -222,10 +208,11 @@ TEST(OpArangeStartOutTest, OutDimUnsupportedDie) { Tensor out = tf.zeros({5, 1}); // out.dim() should be 1, not 2 - ET_EXPECT_KERNEL_FAILURE(op_arange_start_out(start, end, step, out)); + ET_EXPECT_KERNEL_FAILURE( + context_, op_arange_start_out(start, end, step, out)); } -TEST(OpArangeStartOutTest, DynamicShapeUpperBoundSameAsExpected) { +TEST_F(OpArangeStartOutTest, DynamicShapeUpperBoundSameAsExpected) { TensorFactory tf; Tensor expected_result = tf.make({5}, {0, 1, 2, 3, 4}); @@ -236,7 +223,7 @@ TEST(OpArangeStartOutTest, DynamicShapeUpperBoundSameAsExpected) { EXPECT_TENSOR_CLOSE(out, expected_result); } -TEST(OpArangeStartOutTest, DynamicShapeUpperBoundLargerThanExpected) { +TEST_F(OpArangeStartOutTest, DynamicShapeUpperBoundLargerThanExpected) { TensorFactory tf; Tensor expected_result = tf.make({5}, {0, 1, 2, 3, 4}); @@ -247,7 +234,7 @@ TEST(OpArangeStartOutTest, DynamicShapeUpperBoundLargerThanExpected) { EXPECT_TENSOR_CLOSE(out, expected_result); } -TEST(OpArangeStartOutTest, DynamicShapeUnbound) { +TEST_F(OpArangeStartOutTest, DynamicShapeUnbound) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic Unbound not supported"; } @@ -261,7 +248,7 @@ TEST(OpArangeStartOutTest, DynamicShapeUnbound) { EXPECT_TENSOR_CLOSE(out, expected_result); } -TEST(OpArangeStartOutTest, StartOut) { +TEST_F(OpArangeStartOutTest, StartOut) { TensorFactory tf; Scalar start = Scalar(1.1); @@ -294,7 +281,7 @@ TEST(OpArangeStartOutTest, StartOut) { EXPECT_TENSOR_EQ(out, expected); } -TEST(OpArangeStartOutTest, StartOutNegativeStep) { +TEST_F(OpArangeStartOutTest, StartOutNegativeStep) { TensorFactory tf; Scalar start = Scalar(5.5); diff --git a/kernels/test/op_argmax_test.cpp b/kernels/test/op_argmax_test.cpp index 96ef45bf0dd..51720ee9be7 100644 --- a/kernels/test/op_argmax_test.cpp +++ b/kernels/test/op_argmax_test.cpp @@ -22,16 +22,18 @@ using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::testing::TensorFactory; -Tensor& op_argmax_out( - const Tensor& in, - optional dim, - bool keepdim, - Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::argmax_outf(context, in, dim, keepdim, out); -} - -TEST(OpArgmaxTest, SanityCheckLong) { +class OpArgmaxTest : public OperatorTest { + protected: + Tensor& op_argmax_out( + const Tensor& in, + optional dim, + bool keepdim, + Tensor& out) { + return torch::executor::aten::argmax_outf(context_, in, dim, keepdim, out); + } +}; + +TEST_F(OpArgmaxTest, SanityCheckLong) { TensorFactory tf; // clang-format off @@ -56,7 +58,7 @@ 
TEST(OpArgmaxTest, SanityCheckLong) { // clang-format on } -TEST(OpArgmaxTest, SanityCheckShort) { +TEST_F(OpArgmaxTest, SanityCheckShort) { TensorFactory tfl; TensorFactory tfs; @@ -82,7 +84,7 @@ TEST(OpArgmaxTest, SanityCheckShort) { // clang-format on } -TEST(OpArgmaxTest, SanityCheckNullDim) { +TEST_F(OpArgmaxTest, SanityCheckNullDim) { TensorFactory tf; // clang-format off diff --git a/kernels/test/op_argmin_test.cpp b/kernels/test/op_argmin_test.cpp index 5cba45a642f..fd63e89ae69 100644 --- a/kernels/test/op_argmin_test.cpp +++ b/kernels/test/op_argmin_test.cpp @@ -22,16 +22,18 @@ using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::testing::TensorFactory; -Tensor& op_argmin_out( - const Tensor& in, - optional dim, - bool keepdim, - Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::argmin_outf(context, in, dim, keepdim, out); -} - -TEST(OpArgminTest, SanityCheckLong) { +class OpArgminTest : public OperatorTest { + protected: + Tensor& op_argmin_out( + const Tensor& in, + optional dim, + bool keepdim, + Tensor& out) { + return torch::executor::aten::argmin_outf(context_, in, dim, keepdim, out); + } +}; + +TEST_F(OpArgminTest, SanityCheckLong) { TensorFactory tf; // clang-format off @@ -56,7 +58,7 @@ TEST(OpArgminTest, SanityCheckLong) { // clang-format on } -TEST(OpArgminTest, SanityCheckShort) { +TEST_F(OpArgminTest, SanityCheckShort) { TensorFactory tfl; TensorFactory tfs; @@ -82,7 +84,7 @@ TEST(OpArgminTest, SanityCheckShort) { // clang-format on } -TEST(OpArgminTest, SanityCheckNullDim) { +TEST_F(OpArgminTest, SanityCheckNullDim) { TensorFactory tf; // clang-format off diff --git a/kernels/test/op_as_strided_copy_test.cpp b/kernels/test/op_as_strided_copy_test.cpp index e0e78be194b..9d9fadb7db1 100644 --- a/kernels/test/op_as_strided_copy_test.cpp +++ b/kernels/test/op_as_strided_copy_test.cpp @@ -25,166 +25,180 @@ using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::testing::TensorFactory; -Tensor& op_as_strided_copy_out( - const Tensor& self, - ArrayRef size, - ArrayRef stride, - optional storage_offset, - Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::as_strided_copy_outf( - context, self, size, stride, storage_offset, out); -} - -// Common testing for eq operator -template -void test_detach_copy_out() { - TensorFactory tf; - const std::vector in_sizes = {3, 3}; - const std::vector out_sizes = {2, 2, 2}; - - Tensor in = tf.make(in_sizes, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - Tensor out = tf.zeros(out_sizes); - - // Valid input should give the expected output - optional storage_offset; - int64_t sizes[3] = {2, 2, 2}; - int64_t stride[3] = {1, 2, 3}; - op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride, 3}, - storage_offset, - out); - EXPECT_TENSOR_EQ(out, tf.make(out_sizes, {1, 4, 3, 6, 2, 5, 4, 7})); - - // With storage offset - op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride, 3}, - /*storage_offset=*/2, - out); - EXPECT_TENSOR_EQ(out, tf.make(out_sizes, {3, 6, 5, 8, 4, 7, 6, 9})); -} - -template <> -void test_detach_copy_out() { - TensorFactory tf; - const std::vector in_sizes = {3, 3}; - const std::vector out_sizes = {2, 2, 2}; - Tensor in = tf.make( - in_sizes, {false, true, false, true, false, true, false, true, false}); - Tensor out = tf.zeros(out_sizes); +class OpAsStridedCopyOutTest : public OperatorTest { + protected: + Tensor& op_as_strided_copy_out( + 
const Tensor& self, + ArrayRef size, + ArrayRef stride, + optional storage_offset, + Tensor& out) { + return torch::executor::aten::as_strided_copy_outf( + context_, self, size, stride, storage_offset, out); + } - // Valid input should give the expected output - optional storage_offset = 2; - int64_t sizes[3] = {2, 2, 2}; - int64_t stride[3] = {1, 2, 3}; - op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride, 3}, - storage_offset, - out); - EXPECT_TENSOR_EQ( - out, - tf.make(out_sizes, {false, true, false, true, true, false, true, false})); -} + // Common testing for eq operator + template + void test_detach_copy_out() { + TensorFactory tf; + const std::vector in_sizes = {3, 3}; + const std::vector out_sizes = {2, 2, 2}; + + Tensor in = tf.make(in_sizes, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + Tensor out = tf.zeros(out_sizes); + + // Valid input should give the expected output + optional storage_offset; + int64_t sizes[3] = {2, 2, 2}; + int64_t stride[3] = {1, 2, 3}; + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride, 3}, + storage_offset, + out); + EXPECT_TENSOR_EQ(out, tf.make(out_sizes, {1, 4, 3, 6, 2, 5, 4, 7})); + + // With storage offset + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride, 3}, + /*storage_offset=*/2, + out); + EXPECT_TENSOR_EQ(out, tf.make(out_sizes, {3, 6, 5, 8, 4, 7, 6, 9})); + } -template <> -void test_detach_copy_out() { - TensorFactory tf; - const std::vector in_sizes = {3, 3}; - const std::vector out_sizes = {2, 2, 2}; + template <> + void test_detach_copy_out() { + TensorFactory tf; + const std::vector in_sizes = {3, 3}; + const std::vector out_sizes = {2, 2, 2}; + Tensor in = tf.make( + in_sizes, {false, true, false, true, false, true, false, true, false}); + Tensor out = tf.zeros(out_sizes); + + // Valid input should give the expected output + optional storage_offset = 2; + int64_t sizes[3] = {2, 2, 2}; + int64_t stride[3] = {1, 2, 3}; + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride, 3}, + storage_offset, + out); + EXPECT_TENSOR_EQ( + out, + tf.make( + out_sizes, {false, true, false, true, true, false, true, false})); + } - Tensor in = tf.make( - in_sizes, {3.14, 2.33, 42, INFINITY, -INFINITY, NAN, -3.14, -2.33, -42}); - Tensor out = tf.zeros(out_sizes); + template <> + void test_detach_copy_out() { + TensorFactory tf; + const std::vector in_sizes = {3, 3}; + const std::vector out_sizes = {2, 2, 2}; + + Tensor in = tf.make( + in_sizes, + {3.14, 2.33, 42, INFINITY, -INFINITY, NAN, -3.14, -2.33, -42}); + Tensor out = tf.zeros(out_sizes); + + // Valid input should give the expected output + optional storage_offset = 2; + int64_t sizes[3] = {2, 2, 2}; + int64_t stride[3] = {1, 2, 3}; + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride, 3}, + storage_offset, + out); + EXPECT_TENSOR_CLOSE( + out, + tf.make( + out_sizes, + {42.0, NAN, -INFINITY, 2.33, INFINITY, -3.14, NAN, -42.0})); + } - // Valid input should give the expected output - optional storage_offset = 2; - int64_t sizes[3] = {2, 2, 2}; - int64_t stride[3] = {1, 2, 3}; - op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride, 3}, - storage_offset, - out); - EXPECT_TENSOR_CLOSE( - out, - tf.make( - out_sizes, - {42.0, NAN, -INFINITY, 2.33, INFINITY, -3.14, NAN, -42.0})); -} + template + void 
test_as_strided_copy_out_invalid_parameters() { + TensorFactory tf; + + const std::vector in_sizes = {3, 3}; + const std::vector out_sizes = {2, 2, 2}; + + Tensor in = tf.ones(in_sizes); + Tensor out = tf.zeros(out_sizes); + optional storage_offset; + int64_t sizes[3] = {2, 2, 2}; + int64_t stride[3] = {1, 2, 3}; + + // Mismatch strides and shape should die + int64_t stride_short[2] = {1, 2}; + ET_EXPECT_KERNEL_FAILURE( + context_, + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride_short, 2}, + storage_offset, + out)); + + // Negative strides should die + int64_t stride_negative[3] = {1, 2, -1}; + ET_EXPECT_KERNEL_FAILURE( + context_, + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride_negative, 3}, + storage_offset, + out)); + + // Mismatch output tensor shape and size should die + int64_t size_invalid[3] = {2, 2, 1}; + ET_EXPECT_KERNEL_FAILURE( + context_, + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{size_invalid, 3}, + /*stride=*/ArrayRef{stride, 3}, + storage_offset, + out)); + + // Invalid storage offset should die + storage_offset = -1; + ET_EXPECT_KERNEL_FAILURE( + context_, + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride, 3}, + storage_offset, + out)); + + // Out of bound storage access of `in` should die + storage_offset = 3; + ET_EXPECT_KERNEL_FAILURE( + context_, + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride, 3}, + storage_offset, + out)); + } +}; -TEST(OpAsStridedCopyOutKernelTest, AllScalarInputOutputSupport) { +TEST_F(OpAsStridedCopyOutTest, AllScalarInputOutputSupport) { #define TEST_ENTRY(ctype, dtype) test_detach_copy_out(); ET_FORALL_INT_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -template -void test_as_strided_copy_out_invalid_parameters() { - TensorFactory tf; - - const std::vector in_sizes = {3, 3}; - const std::vector out_sizes = {2, 2, 2}; - - Tensor in = tf.ones(in_sizes); - Tensor out = tf.zeros(out_sizes); - optional storage_offset; - int64_t sizes[3] = {2, 2, 2}; - int64_t stride[3] = {1, 2, 3}; - - // Mismatch strides and shape should die - int64_t stride_short[2] = {1, 2}; - ET_EXPECT_KERNEL_FAILURE(op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride_short, 2}, - storage_offset, - out)); - - // Negative strides should die - int64_t stride_negative[3] = {1, 2, -1}; - ET_EXPECT_KERNEL_FAILURE(op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride_negative, 3}, - storage_offset, - out)); - - // Mismatch output tensor shape and size should die - int64_t size_invalid[3] = {2, 2, 1}; - ET_EXPECT_KERNEL_FAILURE(op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{size_invalid, 3}, - /*stride=*/ArrayRef{stride, 3}, - storage_offset, - out)); - - // Invalid storage offset should die - storage_offset = -1; - ET_EXPECT_KERNEL_FAILURE(op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride, 3}, - storage_offset, - out)); - - // Out of bound storage access of `in` should die - storage_offset = 3; - ET_EXPECT_KERNEL_FAILURE(op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride, 3}, - storage_offset, - out)); -} - -TEST(OpAsStridedCopyOutKernelTest, InvalidParametersDies) { +TEST_F(OpAsStridedCopyOutTest, InvalidParametersDies) { if 
(torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel can handle invalid parameter"; } @@ -194,7 +208,7 @@ TEST(OpAsStridedCopyOutKernelTest, InvalidParametersDies) { #undef TEST_ENTRY } -TEST(OpAsStridedCopyOutKernelTest, MismatchedInputDtypesDies) { +TEST_F(OpAsStridedCopyOutTest, MismatchedInputDtypesDies) { TensorFactory tf_byte; TensorFactory tf_char; const std::vector in_sizes = {3, 3}; @@ -206,12 +220,14 @@ TEST(OpAsStridedCopyOutKernelTest, MismatchedInputDtypesDies) { int64_t sizes[3] = {2, 2, 2}; int64_t stride[3] = {1, 2, 3}; - ET_EXPECT_KERNEL_FAILURE(op_as_strided_copy_out( - /*self=*/in, - /*size=*/ArrayRef{sizes, 3}, - /*stride=*/ArrayRef{stride, 3}, - storage_offset, - out)); + ET_EXPECT_KERNEL_FAILURE( + context_, + op_as_strided_copy_out( + /*self=*/in, + /*size=*/ArrayRef{sizes, 3}, + /*stride=*/ArrayRef{stride, 3}, + storage_offset, + out)); } /* %python @@ -229,7 +245,7 @@ opt_extra_params = "size, stride, storage_offset," dtype = "ScalarType::Float" check = "EXPECT_TENSOR_EQ" */ -TEST(OpAsStridedCopyOutKernelTest, DynamicShapeUpperBoundSameAsExpected) { +TEST_F(OpAsStridedCopyOutTest, DynamicShapeUpperBoundSameAsExpected) { /* %python out_args = "{2, 2, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND" %rewrite(unary_op) */ @@ -270,7 +286,7 @@ TEST(OpAsStridedCopyOutKernelTest, DynamicShapeUpperBoundSameAsExpected) { EXPECT_TENSOR_EQ(out, expected); } -TEST(OpAsStridedCopyOutKernelTest, DynamicShapeUpperBoundLargerThanExpected) { +TEST_F(OpAsStridedCopyOutTest, DynamicShapeUpperBoundLargerThanExpected) { /* %python out_args = "{5, 5, 5}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND" %rewrite(unary_op) */ @@ -311,7 +327,7 @@ TEST(OpAsStridedCopyOutKernelTest, DynamicShapeUpperBoundLargerThanExpected) { EXPECT_TENSOR_EQ(out, expected); } -TEST(OpAsStridedCopyOutKernelTest, DynamicShapeUnbound) { +TEST_F(OpAsStridedCopyOutTest, DynamicShapeUnbound) { if (!torch::executor::testing::SupportedFeatures::get()->output_resize) { GTEST_SKIP() << "Dynamic shape unbound not supported"; } diff --git a/kernels/test/op_asin_test.cpp b/kernels/test/op_asin_test.cpp index 548868f9b8d..ae0af71d2de 100644 --- a/kernels/test/op_asin_test.cpp +++ b/kernels/test/op_asin_test.cpp @@ -21,12 +21,49 @@ using exec_aten::Tensor; using exec_aten::TensorShapeDynamism; using torch::executor::testing::TensorFactory; -Tensor& op_asin_out(const Tensor& self, Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::asin_outf(context, self, out); -} +class OpAsinOutTest : public OperatorTest { + protected: + Tensor& op_asin_out(const Tensor& self, Tensor& out) { + return torch::executor::aten::asin_outf(context_, self, out); + } + + // Common testing for asin operator and all kinds of supported input types + template + void test_floating_point_asin_out( + const std::vector& out_shape = {1, 6}, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + TensorFactory tf_in; + TensorFactory tf_out; + + // Destination for the asin operator. + Tensor out = tf_out.zeros(out_shape, dynamism); + + // clang-format off + op_asin_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out); + + // Check that it matches (or close to) the expected output. + EXPECT_TENSOR_CLOSE( + out, + tf_out.make({1, 6}, { 0.000000, 1.570796, NAN, NAN, NAN, NAN })); + // clang-format on + } + + // Unhandled output dtypes. 
+ template + void test_asin_invalid_output_dtype_dies() { + TensorFactory tf; + TensorFactory tf_out; + + const std::vector sizes = {2, 5}; + + Tensor in = tf.ones(sizes); + Tensor out = tf_out.zeros(sizes); + + ET_EXPECT_KERNEL_FAILURE(context_, op_asin_out(in, out)); + } +}; -TEST(OpAsinOutKernelTest, HandleBoolInput) { +TEST_F(OpAsinOutTest, HandleBoolInput) { TensorFactory tf_bool; TensorFactory tf_float; @@ -39,28 +76,7 @@ TEST(OpAsinOutKernelTest, HandleBoolInput) { EXPECT_TENSOR_CLOSE(op_asin_out(a, out), res); } -// Common testing for asin operator and all kinds of supported input types -template -void test_floating_point_asin_out( - const std::vector& out_shape = {1, 6}, - TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { - TensorFactory tf_in; - TensorFactory tf_out; - - // Destination for the asin operator. - Tensor out = tf_out.zeros(out_shape, dynamism); - - // clang-format off - op_asin_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out); - - // Check that it matches (or close to) the expected output. - EXPECT_TENSOR_CLOSE( - out, - tf_out.make({1, 6}, { 0.000000, 1.570796, NAN, NAN, NAN, NAN })); - // clang-format on -} - -TEST(OpAsinOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputHalfOutputStaticDynamismSupport) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Test Half support only for ExecuTorch mode"; } @@ -70,21 +86,21 @@ TEST(OpAsinOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinOutKernelTest, AllRealInputFloatOutputStaticDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputFloatOutputStaticDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asin_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpAsinOutKernelTest, AllRealInputDoubleOutputStaticDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputDoubleOutputStaticDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asin_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpAsinOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputHalfOutputBoundDynamismSupport) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Test Half support only for ExecuTorch mode"; } @@ -95,7 +111,7 @@ TEST(OpAsinOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputFloatOutputBoundDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asin_out( \ {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND); @@ -103,7 +119,7 @@ TEST(OpAsinOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputDoubleOutputBoundDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asin_out( \ {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND); @@ -111,7 +127,7 @@ TEST(OpAsinOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputFloatOutputUnboundDynamismSupport) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic shape unbound not supported"; } @@ -122,7 +138,7 @@ TEST(OpAsinOutKernelTest, 
AllRealInputFloatOutputUnboundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) { +TEST_F(OpAsinOutTest, AllRealInputDoubleOutputUnboundDynamismSupport) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic shape unbound not supported"; } @@ -133,21 +149,7 @@ TEST(OpAsinOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) { #undef TEST_ENTRY } -// Unhandled output dtypes. -template -void test_asin_invalid_output_dtype_dies() { - TensorFactory tf; - TensorFactory tf_out; - - const std::vector sizes = {2, 5}; - - Tensor in = tf.ones(sizes); - Tensor out = tf_out.zeros(sizes); - - ET_EXPECT_KERNEL_FAILURE(op_asin_out(in, out)); -} - -TEST(OpAsinOutKernelTest, AllNonFloatOutputDTypeDies) { +TEST_F(OpAsinOutTest, AllNonFloatOutputDTypeDies) { #define TEST_ENTRY(ctype, dtype) \ test_asin_invalid_output_dtype_dies(); ET_FORALL_INT_TYPES(TEST_ENTRY); @@ -155,7 +157,7 @@ TEST(OpAsinOutKernelTest, AllNonFloatOutputDTypeDies) { } // Mismatched shape tests. -TEST(OpAsinOutKernelTest, MismatchedInputShapesDies) { +TEST_F(OpAsinOutTest, MismatchedInputShapesDies) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel can handle mismatched input shapes"; } @@ -164,5 +166,5 @@ TEST(OpAsinOutKernelTest, MismatchedInputShapesDies) { Tensor a = tf.ones(/*sizes=*/{4}); Tensor out = tf.ones(/*sizes=*/{2, 2}); - ET_EXPECT_KERNEL_FAILURE(op_asin_out(a, out)); + ET_EXPECT_KERNEL_FAILURE(context_, op_asin_out(a, out)); } diff --git a/kernels/test/op_asinh_test.cpp b/kernels/test/op_asinh_test.cpp index 48a622e0281..cd887404b75 100644 --- a/kernels/test/op_asinh_test.cpp +++ b/kernels/test/op_asinh_test.cpp @@ -21,12 +21,49 @@ using exec_aten::Tensor; using exec_aten::TensorShapeDynamism; using torch::executor::testing::TensorFactory; -Tensor& op_asinh_out(const Tensor& self, Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::asinh_outf(context, self, out); -} +class OpAsinhOutTest : public OperatorTest { + protected: + Tensor& op_asinh_out(const Tensor& self, Tensor& out) { + return torch::executor::aten::asinh_outf(context_, self, out); + } + + // Common testing for asinh operator and all kinds of supported input types + template + void test_floating_point_asinh_out( + const std::vector& out_shape = {1, 6}, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + TensorFactory tf_in; + TensorFactory tf_out; + + // Destination for the asinh operator. + Tensor out = tf_out.zeros(out_shape, dynamism); + + // clang-format off + op_asinh_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out); + + // Check that it matches (or close to) the expected output. + EXPECT_TENSOR_CLOSE( + out, + tf_out.make({1, 6}, { 0.000000, 0.881374, 1.818447, 2.312438, 2.998223, 5.298342 })); + // clang-format on + } + + // Unhandled output dtypes. 
+ template + void test_asinh_invalid_output_dtype_dies() { + TensorFactory tf; + TensorFactory tf_out; + + const std::vector sizes = {2, 5}; + + Tensor in = tf.ones(sizes); + Tensor out = tf_out.zeros(sizes); + + ET_EXPECT_KERNEL_FAILURE(context_, op_asinh_out(in, out)); + } +}; -TEST(OpAsinhOutKernelTest, HandleBoolInput) { +TEST_F(OpAsinhOutTest, HandleBoolInput) { TensorFactory tf_bool; TensorFactory tf_float; @@ -39,28 +76,7 @@ TEST(OpAsinhOutKernelTest, HandleBoolInput) { EXPECT_TENSOR_CLOSE(op_asinh_out(a, out), res); } -// Common testing for asinh operator and all kinds of supported input types -template -void test_floating_point_asinh_out( - const std::vector& out_shape = {1, 6}, - TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { - TensorFactory tf_in; - TensorFactory tf_out; - - // Destination for the asinh operator. - Tensor out = tf_out.zeros(out_shape, dynamism); - - // clang-format off - op_asinh_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out); - - // Check that it matches (or close to) the expected output. - EXPECT_TENSOR_CLOSE( - out, - tf_out.make({1, 6}, { 0.000000, 0.881374, 1.818447, 2.312438, 2.998223, 5.298342 })); - // clang-format on -} - -TEST(OpAsinhOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputHalfOutputStaticDynamismSupport) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Test Half support only for ExecuTorch mode"; } @@ -70,21 +86,21 @@ TEST(OpAsinhOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinhOutKernelTest, AllRealInputFloatOutputStaticDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputFloatOutputStaticDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asinh_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpAsinhOutKernelTest, AllRealInputDoubleOutputStaticDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputDoubleOutputStaticDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asinh_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpAsinhOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputHalfOutputBoundDynamismSupport) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Test Half support only for ExecuTorch mode"; } @@ -95,7 +111,7 @@ TEST(OpAsinhOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinhOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputFloatOutputBoundDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asinh_out( \ {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND); @@ -103,7 +119,7 @@ TEST(OpAsinhOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinhOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputDoubleOutputBoundDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_asinh_out( \ {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND); @@ -111,7 +127,7 @@ TEST(OpAsinhOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinhOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputFloatOutputUnboundDynamismSupport) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic shape unbound not supported"; } @@ 
-122,7 +138,7 @@ TEST(OpAsinhOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAsinhOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) { +TEST_F(OpAsinhOutTest, AllRealInputDoubleOutputUnboundDynamismSupport) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic shape unbound not supported"; } @@ -133,21 +149,7 @@ TEST(OpAsinhOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) { #undef TEST_ENTRY } -// Unhandled output dtypes. -template -void test_asinh_invalid_output_dtype_dies() { - TensorFactory tf; - TensorFactory tf_out; - - const std::vector sizes = {2, 5}; - - Tensor in = tf.ones(sizes); - Tensor out = tf_out.zeros(sizes); - - ET_EXPECT_KERNEL_FAILURE(op_asinh_out(in, out)); -} - -TEST(OpAsinhOutKernelTest, AllNonFloatOutputDTypeDies) { +TEST_F(OpAsinhOutTest, AllNonFloatOutputDTypeDies) { #define TEST_ENTRY(ctype, dtype) \ test_asinh_invalid_output_dtype_dies(); ET_FORALL_INT_TYPES(TEST_ENTRY); @@ -155,7 +157,7 @@ TEST(OpAsinhOutKernelTest, AllNonFloatOutputDTypeDies) { } // Mismatched shape tests. -TEST(OpAsinhOutKernelTest, MismatchedInputShapesDies) { +TEST_F(OpAsinhOutTest, MismatchedInputShapesDies) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel can handle mismatched input shapes"; } @@ -164,5 +166,5 @@ TEST(OpAsinhOutKernelTest, MismatchedInputShapesDies) { Tensor a = tf.ones(/*sizes=*/{4}); Tensor out = tf.ones(/*sizes=*/{2, 2}); - ET_EXPECT_KERNEL_FAILURE(op_asinh_out(a, out)); + ET_EXPECT_KERNEL_FAILURE(context_, op_asinh_out(a, out)); } diff --git a/kernels/test/op_atan_test.cpp b/kernels/test/op_atan_test.cpp index 821e52b9e14..6258819432f 100644 --- a/kernels/test/op_atan_test.cpp +++ b/kernels/test/op_atan_test.cpp @@ -21,12 +21,49 @@ using exec_aten::Tensor; using exec_aten::TensorShapeDynamism; using torch::executor::testing::TensorFactory; -Tensor& op_atan_out(const Tensor& self, Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::atan_outf(context, self, out); -} +class OpAtanOutTest : public OperatorTest { + protected: + Tensor& op_atan_out(const Tensor& self, Tensor& out) { + return torch::executor::aten::atan_outf(context_, self, out); + } + + // Common testing for atan operator and all kinds of supported input types + template + void test_floating_point_atan_out( + const std::vector& out_shape = {1, 6}, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + TensorFactory tf_in; + TensorFactory tf_out; + + // Destination for the atan operator. + Tensor out = tf_out.zeros(out_shape, dynamism); + + // clang-format off + op_atan_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out); + + // Check that it matches (or close to) the expected output. + EXPECT_TENSOR_CLOSE( + out, + tf_out.make({1, 6}, { 0.000000, 0.785398, 1.249046, 1.373401, 1.471128, 1.560797 })); + // clang-format on + } + + // Unhandled output dtypes. 
+ template + void test_atan_invalid_output_dtype_dies() { + TensorFactory tf; + TensorFactory tf_out; + + const std::vector sizes = {2, 5}; + + Tensor in = tf.ones(sizes); + Tensor out = tf_out.zeros(sizes); + + ET_EXPECT_KERNEL_FAILURE(context_, op_atan_out(in, out)); + } +}; -TEST(OpAtanOutKernelTest, HandleBoolInput) { +TEST_F(OpAtanOutTest, HandleBoolInput) { TensorFactory tf_bool; TensorFactory tf_float; @@ -39,28 +76,7 @@ TEST(OpAtanOutKernelTest, HandleBoolInput) { EXPECT_TENSOR_CLOSE(op_atan_out(a, out), res); } -// Common testing for atan operator and all kinds of supported input types -template -void test_floating_point_atan_out( - const std::vector& out_shape = {1, 6}, - TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { - TensorFactory tf_in; - TensorFactory tf_out; - - // Destination for the atan operator. - Tensor out = tf_out.zeros(out_shape, dynamism); - - // clang-format off - op_atan_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out); - - // Check that it matches (or close to) the expected output. - EXPECT_TENSOR_CLOSE( - out, - tf_out.make({1, 6}, { 0.000000, 0.785398, 1.249046, 1.373401, 1.471128, 1.560797 })); - // clang-format on -} - -TEST(OpAtanOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputHalfOutputStaticDynamismSupport) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Test Half support only for ExecuTorch mode"; } @@ -70,21 +86,21 @@ TEST(OpAtanOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAtanOutKernelTest, AllRealInputFloatOutputStaticDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputFloatOutputStaticDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_atan_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpAtanOutKernelTest, AllRealInputDoubleOutputStaticDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputDoubleOutputStaticDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_atan_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); #undef TEST_ENTRY } -TEST(OpAtanOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputHalfOutputBoundDynamismSupport) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Test Half support only for ExecuTorch mode"; } @@ -95,7 +111,7 @@ TEST(OpAtanOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAtanOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputFloatOutputBoundDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_atan_out( \ {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND); @@ -103,7 +119,7 @@ TEST(OpAtanOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAtanOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputDoubleOutputBoundDynamismSupport) { #define TEST_ENTRY(ctype, dtype) \ test_floating_point_atan_out( \ {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND); @@ -111,7 +127,7 @@ TEST(OpAtanOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAtanOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputFloatOutputUnboundDynamismSupport) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic shape unbound not supported"; } @@ -122,7 +138,7 @@ 
TEST(OpAtanOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) { #undef TEST_ENTRY } -TEST(OpAtanOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) { +TEST_F(OpAtanOutTest, AllRealInputDoubleOutputUnboundDynamismSupport) { if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "Dynamic shape unbound not supported"; } @@ -133,21 +149,7 @@ TEST(OpAtanOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) { #undef TEST_ENTRY } -// Unhandled output dtypes. -template -void test_atan_invalid_output_dtype_dies() { - TensorFactory tf; - TensorFactory tf_out; - - const std::vector sizes = {2, 5}; - - Tensor in = tf.ones(sizes); - Tensor out = tf_out.zeros(sizes); - - ET_EXPECT_KERNEL_FAILURE(op_atan_out(in, out)); -} - -TEST(OpAtanOutKernelTest, AllNonFloatOutputDTypeDies) { +TEST_F(OpAtanOutTest, AllNonFloatOutputDTypeDies) { #define TEST_ENTRY(ctype, dtype) \ test_atan_invalid_output_dtype_dies(); ET_FORALL_INT_TYPES(TEST_ENTRY); @@ -155,7 +157,7 @@ TEST(OpAtanOutKernelTest, AllNonFloatOutputDTypeDies) { } // Mismatched shape tests. -TEST(OpAtanOutKernelTest, MismatchedInputShapesDies) { +TEST_F(OpAtanOutTest, MismatchedInputShapesDies) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel can handle mismatched input shapes"; } @@ -164,5 +166,5 @@ TEST(OpAtanOutKernelTest, MismatchedInputShapesDies) { Tensor a = tf.ones(/*sizes=*/{4}); Tensor out = tf.ones(/*sizes=*/{2, 2}); - ET_EXPECT_KERNEL_FAILURE(op_atan_out(a, out)); + ET_EXPECT_KERNEL_FAILURE(context_, op_atan_out(a, out)); } diff --git a/kernels/test/op_atanh_test.cpp b/kernels/test/op_atanh_test.cpp index 5d1156ca550..88f02603c85 100644 --- a/kernels/test/op_atanh_test.cpp +++ b/kernels/test/op_atanh_test.cpp @@ -21,12 +21,49 @@ using exec_aten::Tensor; using exec_aten::TensorShapeDynamism; using torch::executor::testing::TensorFactory; -Tensor& op_atanh_out(const Tensor& self, Tensor& out) { - exec_aten::RuntimeContext context{}; - return torch::executor::aten::atanh_outf(context, self, out); -} +class OpAtanhOutTest : public OperatorTest { + protected: + Tensor& op_atanh_out(const Tensor& self, Tensor& out) { + return torch::executor::aten::atanh_outf(context_, self, out); + } + + // Common testing for atanh operator and all kinds of supported input types + template + void test_floating_point_atanh_out( + const std::vector& out_shape = {1, 6}, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) { + TensorFactory tf_in; + TensorFactory tf_out; + + // Destination for the atanh operator. + Tensor out = tf_out.zeros(out_shape, dynamism); + + // clang-format off + op_atanh_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out); + + // Check that it matches (or close to) the expected output. + EXPECT_TENSOR_CLOSE( + out, + tf_out.make({1, 6}, { 0.0, std::numeric_limits::infinity(), NAN, NAN, NAN, NAN })); + // clang-format on + } + + // Unhandled output dtypes. 
diff --git a/kernels/test/op_atanh_test.cpp b/kernels/test/op_atanh_test.cpp
index 5d1156ca550..88f02603c85 100644
--- a/kernels/test/op_atanh_test.cpp
+++ b/kernels/test/op_atanh_test.cpp
@@ -21,12 +21,49 @@ using exec_aten::Tensor;
 using exec_aten::TensorShapeDynamism;
 using torch::executor::testing::TensorFactory;

-Tensor& op_atanh_out(const Tensor& self, Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::atanh_outf(context, self, out);
-}
+class OpAtanhOutTest : public OperatorTest {
+ protected:
+  Tensor& op_atanh_out(const Tensor& self, Tensor& out) {
+    return torch::executor::aten::atanh_outf(context_, self, out);
+  }
+
+  // Common testing for atanh operator and all kinds of supported input types
+  template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
+  void test_floating_point_atanh_out(
+      const std::vector<int32_t>& out_shape = {1, 6},
+      TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<OUT_DTYPE> tf_out;
+
+    // Destination for the atanh operator.
+    Tensor out = tf_out.zeros(out_shape, dynamism);
+
+    // clang-format off
+    op_atanh_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out);
+
+    // Check that it matches (or close to) the expected output.
+    EXPECT_TENSOR_CLOSE(
+        out,
+        tf_out.make({1, 6}, { 0.0, std::numeric_limits<double>::infinity(), NAN, NAN, NAN, NAN }));
+    // clang-format on
+  }
+
+  // Unhandled output dtypes.
+  template <ScalarType OUTPUT_DTYPE>
+  void test_atanh_invalid_output_dtype_dies() {
+    TensorFactory<ScalarType::Float> tf;
+    TensorFactory<OUTPUT_DTYPE> tf_out;
+
+    const std::vector<int32_t> sizes = {2, 5};
+
+    Tensor in = tf.ones(sizes);
+    Tensor out = tf_out.zeros(sizes);
+
+    ET_EXPECT_KERNEL_FAILURE(context_, op_atanh_out(in, out));
+  }
+};

-TEST(OpAtanhOutKernelTest, HandleBoolInput) {
+TEST_F(OpAtanhOutTest, HandleBoolInput) {
   TensorFactory<ScalarType::Bool> tf_bool;
   TensorFactory<ScalarType::Float> tf_float;

@@ -39,28 +76,7 @@ TEST(OpAtanhOutKernelTest, HandleBoolInput) {
   EXPECT_TENSOR_CLOSE(op_atanh_out(a, out), res);
 }

-// Common testing for atanh operator and all kinds of supported input types
-template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
-void test_floating_point_atanh_out(
-    const std::vector<int32_t>& out_shape = {1, 6},
-    TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC) {
-  TensorFactory<IN_DTYPE> tf_in;
-  TensorFactory<OUT_DTYPE> tf_out;
-
-  // Destination for the atanh operator.
-  Tensor out = tf_out.zeros(out_shape, dynamism);
-
-  // clang-format off
-  op_atanh_out(tf_in.make({1, 6}, { 0, 1, 3, 5, 10, 100 }), out);
-
-  // Check that it matches (or close to) the expected output.
-  EXPECT_TENSOR_CLOSE(
-      out,
-      tf_out.make({1, 6}, { 0.0, std::numeric_limits<double>::infinity(), NAN, NAN, NAN, NAN }));
-  // clang-format on
-}
-
-TEST(OpAtanhOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputHalfOutputStaticDynamismSupport) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
   }
@@ -70,21 +86,21 @@ TEST(OpAtanhOutKernelTest, AllRealInputHalfOutputStaticDynamismSupport) {
 #undef TEST_ENTRY
 }

-TEST(OpAtanhOutKernelTest, AllRealInputFloatOutputStaticDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputFloatOutputStaticDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_atanh_out<ScalarType::dtype, ScalarType::Float>();
   ET_FORALL_REAL_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }

-TEST(OpAtanhOutKernelTest, AllRealInputDoubleOutputStaticDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputDoubleOutputStaticDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_atanh_out<ScalarType::dtype, ScalarType::Double>();
   ET_FORALL_REAL_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }

-TEST(OpAtanhOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputHalfOutputBoundDynamismSupport) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
   }
@@ -95,7 +111,7 @@ TEST(OpAtanhOutKernelTest, AllRealInputHalfOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }

-TEST(OpAtanhOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputFloatOutputBoundDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_atanh_out<ScalarType::dtype, ScalarType::Float>( \
       {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND);
@@ -103,7 +119,7 @@ TEST(OpAtanhOutKernelTest, AllRealInputFloatOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }

-TEST(OpAtanhOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputDoubleOutputBoundDynamismSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_floating_point_atanh_out<ScalarType::dtype, ScalarType::Double>( \
       {10, 10}, TensorShapeDynamism::DYNAMIC_BOUND);
@@ -111,7 +127,7 @@ TEST(OpAtanhOutKernelTest, AllRealInputDoubleOutputBoundDynamismSupport) {
 #undef TEST_ENTRY
 }

-TEST(OpAtanhOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputFloatOutputUnboundDynamismSupport) {
   if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Dynamic shape unbound not supported";
   }
@@ -122,7 +138,7 @@ TEST(OpAtanhOutKernelTest, AllRealInputFloatOutputUnboundDynamismSupport) {
 #undef TEST_ENTRY
 }

-TEST(OpAtanhOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
+TEST_F(OpAtanhOutTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
   if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "Dynamic shape unbound not supported";
   }
@@ -133,21 +149,7 @@ TEST(OpAtanhOutKernelTest, AllRealInputDoubleOutputUnboundDynamismSupport) {
 #undef TEST_ENTRY
 }

-// Unhandled output dtypes.
-template <ScalarType OUTPUT_DTYPE>
-void test_atanh_invalid_output_dtype_dies() {
-  TensorFactory<ScalarType::Float> tf;
-  TensorFactory<OUTPUT_DTYPE> tf_out;
-
-  const std::vector<int32_t> sizes = {2, 5};
-
-  Tensor in = tf.ones(sizes);
-  Tensor out = tf_out.zeros(sizes);
-
-  ET_EXPECT_KERNEL_FAILURE(op_atanh_out(in, out));
-}
-
-TEST(OpAtanhOutKernelTest, AllNonFloatOutputDTypeDies) {
+TEST_F(OpAtanhOutTest, AllNonFloatOutputDTypeDies) {
 #define TEST_ENTRY(ctype, dtype) \
   test_atanh_invalid_output_dtype_dies<ScalarType::dtype>();
   ET_FORALL_INT_TYPES(TEST_ENTRY);
@@ -155,7 +157,7 @@ TEST(OpAtanhOutKernelTest, AllNonFloatOutputDTypeDies) {
 }

 // Mismatched shape tests.
-TEST(OpAtanhOutKernelTest, MismatchedInputShapesDies) {
+TEST_F(OpAtanhOutTest, MismatchedInputShapesDies) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
   }
@@ -164,5 +166,5 @@ TEST(OpAtanhOutKernelTest, MismatchedInputShapesDies) {
   Tensor a = tf.ones(/*sizes=*/{4});
   Tensor out = tf.ones(/*sizes=*/{2, 2});

-  ET_EXPECT_KERNEL_FAILURE(op_atanh_out(a, out));
+  ET_EXPECT_KERNEL_FAILURE(context_, op_atanh_out(a, out));
 }
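For reference, the expected tensor in the atanh helper above follows directly from the function's domain: atanh is finite only on the open interval (-1, 1), atanh(1) diverges to +infinity, and |x| > 1 is a domain error, hence {0.0, inf, NaN, NaN, NaN, NaN} for inputs {0, 1, 3, 5, 10, 100}. A standalone check of that math (illustrative only, not part of the test suite):

#include <cmath>
#include <cstdio>

int main() {
  const double inputs[] = {0, 1, 3, 5, 10, 100};
  for (double x : inputs) {
    // std::atanh returns +/-inf at +/-1 and NaN for |x| > 1.
    std::printf("atanh(%g) = %g\n", x, std::atanh(x));
  }
  return 0;
}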
diff --git a/kernels/test/op_avg_pool2d_test.cpp b/kernels/test/op_avg_pool2d_test.cpp
index 2bb3389ba14..90838330fa1 100644
--- a/kernels/test/op_avg_pool2d_test.cpp
+++ b/kernels/test/op_avg_pool2d_test.cpp
@@ -17,29 +17,31 @@
 using namespace ::testing;

-exec_aten::Tensor& op_avg_pool2d_out(
-    const exec_aten::Tensor& self,
-    exec_aten::ArrayRef<int64_t> kernel_size,
-    exec_aten::ArrayRef<int64_t> stride,
-    exec_aten::ArrayRef<int64_t> padding,
-    bool ceil_mode,
-    bool count_include_pad,
-    exec_aten::optional<int64_t> divisor_override,
-    exec_aten::Tensor& out) {
-  exec_aten::RuntimeContext context{};
-  return torch::executor::aten::avg_pool2d_outf(
-      context,
-      self,
-      kernel_size,
-      stride,
-      padding,
-      ceil_mode,
-      count_include_pad,
-      divisor_override,
-      out);
-}
+class OpAvgPool2DOutTest : public OperatorTest {
+ protected:
+  exec_aten::Tensor& op_avg_pool2d_out(
+      const exec_aten::Tensor& self,
+      exec_aten::ArrayRef<int64_t> kernel_size,
+      exec_aten::ArrayRef<int64_t> stride,
+      exec_aten::ArrayRef<int64_t> padding,
+      bool ceil_mode,
+      bool count_include_pad,
+      exec_aten::optional<int64_t> divisor_override,
+      exec_aten::Tensor& out) {
+    return torch::executor::aten::avg_pool2d_outf(
+        context_,
+        self,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode,
+        count_include_pad,
+        divisor_override,
+        out);
+  }
+};

-TEST(OpAvgPool2DOutTest, SanityCheck4D) {
+TEST_F(OpAvgPool2DOutTest, SanityCheck4D) {
   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

   exec_aten::Tensor self = tfFloat.make(
@@ -191,7 +193,7 @@ TEST(OpAvgPool2DOutTest, SanityCheck4D) {
   EXPECT_TENSOR_CLOSE(out, out_expected);
 }

-TEST(OpAvgPool2DOutTest, SanityCheck4DDivisorOverride) {
+TEST_F(OpAvgPool2DOutTest, SanityCheck4DDivisorOverride) {
   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

   exec_aten::Tensor self = tfFloat.make(
@@ -344,7 +346,7 @@ TEST(OpAvgPool2DOutTest, SanityCheck4DDivisorOverride) {
   EXPECT_TENSOR_CLOSE(out, out_expected);
 }

-TEST(OpAvgPool2DOutTest, SanityCheck4DCeilModeNoIncludePadding) {
+TEST_F(OpAvgPool2DOutTest, SanityCheck4DCeilModeNoIncludePadding) {
   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

   exec_aten::Tensor self = tfFloat.make(
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 789179c4cad..f110ec007b8 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -50,10 +50,12 @@ def define_common_targets(is_fbcode = False):
         fbcode_exported_deps = [
             "//common/init:init",
             "//common/gtest:gtest",
+            "//executorch/runtime/kernel:kernel_includes",
         ],
         xplat_exported_deps = [
             "//xplat/folly:init_init",
             "//third-party/googletest:gtest_main",
+            "//executorch/runtime/kernel:kernel_includes",
         ],
     )
diff --git a/kernels/test/util.bzl b/kernels/test/util.bzl
index 0efeb497740..7a7da46d07a 100644
--- a/kernels/test/util.bzl
+++ b/kernels/test/util.bzl
@@ -51,6 +51,7 @@ def op_test(name, deps = [], aten_compatible = True, kernel_name = "portable", u
         deps = [
             "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
             "//executorch/runtime/core/exec_aten/testing_util:tensor_util" + aten_suffix,
+            "//executorch/runtime/kernel:kernel_includes" + aten_suffix,
             "//executorch/kernels/test:test_util" + aten_suffix,
         ] + generated_lib_and_op_deps + deps,
     )
@@ -84,6 +85,7 @@ def generated_op_test(name, op_impl_target, generated_lib_headers_target, suppor
         deps = [
             "//executorch/runtime/core/exec_aten:lib",
             "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+            "//executorch/runtime/kernel:kernel_includes",
             "//executorch/kernels/test:test_util",
             op_impl_target,
             generated_lib_headers_target,
diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h
index f831f826f54..c1917d1dd9e 100644
--- a/runtime/core/exec_aten/util/scalar_type_util.h
+++ b/runtime/core/exec_aten/util/scalar_type_util.h
@@ -360,6 +360,26 @@ inline bool isFloatingType(exec_aten::ScalarType t) {
       t == exec_aten::ScalarType::Half || t == exec_aten::ScalarType::BFloat16);
 }

+inline bool isRealType(exec_aten::ScalarType t) {
+  return (
+      t == exec_aten::ScalarType::Byte || t == exec_aten::ScalarType::Char ||
+      t == exec_aten::ScalarType::Short || t == exec_aten::ScalarType::Int ||
+      t == exec_aten::ScalarType::Long || t == exec_aten::ScalarType::Float ||
+      t == exec_aten::ScalarType::Double);
+}
+
+inline bool isRealHType(exec_aten::ScalarType t) {
+  return (
+      t == exec_aten::ScalarType::Byte || t == exec_aten::ScalarType::Char ||
+      t == exec_aten::ScalarType::Short || t == exec_aten::ScalarType::Int ||
+      t == exec_aten::ScalarType::Long || t == exec_aten::ScalarType::Float ||
+      t == exec_aten::ScalarType::Double || t == exec_aten::ScalarType::Half);
+}
+
+inline bool isRealHBType(exec_aten::ScalarType t) {
+  return (isRealHType(t) || t == exec_aten::ScalarType::Bool);
+}
+
 inline bool isComplexType(exec_aten::ScalarType t) {
   return (
       t == exec_aten::ScalarType::ComplexHalf ||
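The predicate names above encode a small convention: "Real" covers the integer and float/double dtypes, the "H" suffix additionally admits Half, and the "B" suffix additionally admits Bool. A hypothetical call site, shown only to illustrate the naming (output_dtype_ok_for_add is invented, not part of this diff):

#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

inline bool output_dtype_ok_for_add(exec_aten::ScalarType out_type) {
  // add.out writes real, Half, and Bool outputs, so its dtype gate uses the
  // widest of the three predicates.
  return torch::executor::isRealHBType(out_type);
}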
diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h
index f7a4a8d2a99..c5c663e28c5 100644
--- a/runtime/core/exec_aten/util/tensor_util.h
+++ b/runtime/core/exec_aten/util/tensor_util.h
@@ -357,9 +357,6 @@
  * If `cond` is false, log `cond` and return from the kernel with a failure
  * state set.
  *
- * TODO(ssjia): add context.fail(torch.executor::Error::error); before exit
- * TODO(ssjia): replace runtime_abort() with return retval
- *
  * @param[in] context the runtime context
  * @param[in] cond the condition to check
  * @param[in] error torch::executor::Error enum value (e.g `InvalidArgument`)
@@ -369,7 +366,8 @@
   do {                                                            \
     if (!(cond)) {                                                \
       ET_LOG(Error, "Check failed (%s): ", #cond);                \
-      torch::executor::runtime_abort();                           \
+      context.fail(torch::executor::Error::error);                \
+      return retval;                                              \
     }                                                             \
   } while (false)

@@ -377,9 +375,6 @@
  * If `cond` is false, log `message` and return from the kernel with a failure
  * state set.
  *
- * TODO(ssjia): add context.fail(torch.executor::Error::error); before exit
- * TODO(ssjia): replace runtime_abort() with return retval
- *
  * @param[in] context the runtime context
  * @param[in] cond the condition to check
  * @param[in] error torch::executor::Error enum value (e.g `InvalidArgument`)
@@ -389,7 +384,8 @@
   do {                                                                    \
     if (!(cond)) {                                                        \
       ET_LOG(Error, "Check failed (%s): " message, #cond, ##__VA_ARGS__); \
-      torch::executor::runtime_abort();                                   \
+      context.fail(torch::executor::Error::error);                        \
+      return retval;                                                      \
     }                                                                     \
   } while (false)

@@ -491,6 +487,33 @@ inline bool tensor_is_floating_type(exec_aten::Tensor t) {
   return true;
 }

+inline bool tensor_is_real_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isRealType(t.scalar_type()),
+      "Expected to find a real type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_realh_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isRealHType(t.scalar_type()),
+      "Expected to find a real type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
+inline bool tensor_is_realhb_type(exec_aten::Tensor t) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      torch::executor::isRealHBType(t.scalar_type()),
+      "Expected to find a real type, but tensor has type %s",
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
 inline bool tensor_is_complex_type(exec_aten::Tensor t) {
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
       torch::executor::isComplexType(t.scalar_type()),
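With these macro changes, a failed ET_KERNEL_CHECK logs the condition, records the error on the context via context.fail(...), and returns retval from the kernel, replacing the old runtime_abort(). A minimal sketch of the resulting control flow in a kernel, assuming kernel_includes.h is the umbrella header the new :kernel_includes dependency provides and with an invented kernel (example_unary_out) as the body:

#include <executorch/runtime/kernel/kernel_includes.h>

using exec_aten::RuntimeContext;
using exec_aten::Tensor;

Tensor& example_unary_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
  (void)in;
  // On failure this expands to: log the condition, call
  // ctx.fail(Error::InvalidArgument), and `return out;` -- the process keeps
  // running, so a test fixture can inspect the context afterwards.
  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
  // ... actual computation elided ...
  return out;
}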