diff --git a/kernels/test/op_var_test.cpp b/kernels/test/op_var_test.cpp index 5011f1a812f..fbfd16f1b23 100644 --- a/kernels/test/op_var_test.cpp +++ b/kernels/test/op_var_test.cpp @@ -328,6 +328,9 @@ TEST_F(OpVarOutTest, InvalidDTypeDies) { } TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen supports fewer dtypes"; + } // Use a two layer switch to hanldle each possible data pair #define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \ test_var_out_dtype(); @@ -340,6 +343,22 @@ TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses) { #undef TEST_KERNEL } +TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses_Aten) { + if (!torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen-specific variant of test case"; + } + // Use a two layer switch to hanldle each possible data pair +#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \ + test_var_out_dtype(); + +#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ + ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + + ET_FORALL_FLOAT_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +#undef TEST_KERNEL +} + TEST_F(OpVarOutTest, InfinityAndNANTest) { TensorFactory tf_float; // clang-format off